//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/VFABIDemangler.h"
#include "llvm/IR/VectorTypeUtils.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
class TargetLibraryInfo;

/// The Vector Function Database.
///
/// Helper class used to find the vector functions associated with a
/// scalar CallInst.
class VFDatabase {
  /// The Module of the CallInst CI.
  const Module *M;
  /// The CallInst instance being queried for scalar-to-vector mappings.
  const CallInst &CI;
  /// List of vector function descriptors associated with the call
  /// instruction.
  const SmallVector<VFInfo, 8> ScalarToVectorMappings;

  /// Retrieve the scalar-to-vector mappings associated with the rules of
  /// the Vector Function ABI.
  static void getVFABIMappings(const CallInst &CI,
                               SmallVectorImpl<VFInfo> &Mappings) {
    if (!CI.getCalledFunction())
      return;

    const StringRef ScalarName = CI.getCalledFunction()->getName();

    SmallVector<std::string, 8> ListOfStrings;
    // The check for the vector-function-abi-variant attribute is done when
    // retrieving the vector variant names here.
    VFABI::getVectorVariantNames(CI, ListOfStrings);
    if (ListOfStrings.empty())
      return;
    for (const auto &MangledName : ListOfStrings) {
      const std::optional<VFInfo> Shape =
          VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
      // A match is found via scalar and vector names, and also by
      // ensuring that the variant described in the attribute has a
      // corresponding definition or declaration of the vector
      // function in the Module M.
      if (Shape && (Shape->ScalarName == ScalarName)) {
        assert(CI.getModule()->getFunction(Shape->VectorName) &&
               "Vector function is missing.");
        Mappings.push_back(*Shape);
      }
    }
  }

public:
  /// Retrieve all the VFInfo instances associated with the CallInst CI.
  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    SmallVector<VFInfo, 8> Ret;

    // Get mappings from the Vector Function ABI variants.
    getVFABIMappings(CI, Ret);

    // Other non-VFABI variants should be retrieved here.

    return Ret;
  }

  static bool hasMaskedVariant(const CallInst &CI,
                               std::optional<ElementCount> VF = std::nullopt) {
    // Check whether we have at least one masked vector version of a scalar
    // function. If no VF is specified then we check for any masked variant,
    // otherwise we look for one that matches the supplied VF.
    auto Mappings = VFDatabase::getMappings(CI);
    for (VFInfo Info : Mappings)
      if (!VF || Info.Shape.VF == *VF)
        if (Info.isMasked())
          return true;

    return false;
  }
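
  // Illustrative usage sketch (not part of the interface): querying the
  // database for the mappings of a scalar call. The names `Call` and `VF`
  // below are assumptions for the example, not symbols defined in this header.
  //
  //   if (VFDatabase::hasMaskedVariant(*Call, VF)) {
  //     // At least one masked vector variant with vectorization factor VF
  //     // is attached to the scalar call via the VFABI variant attribute.
  //   }
  //   for (const VFInfo &Info : VFDatabase::getMappings(*Call)) {
  //     // Each VFInfo describes one scalar-to-vector mapping (shape and
  //     // vector function name).
  //   }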

  /// Constructor, requires a CallInst instance.
  VFDatabase(CallInst &CI)
      : M(CI.getModule()), CI(CI),
        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}

  /// \defgroup VFDatabase query interface.
  ///
  /// @{
  /// Retrieve the Function with VFShape \p Shape.
  Function *getVectorizedFunction(const VFShape &Shape) const {
    if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
      return CI.getCalledFunction();

    for (const auto &Info : ScalarToVectorMappings)
      if (Info.Shape == Shape)
        return M->getFunction(Info.VectorName);

    return nullptr;
  }
  /// @}
};

template <typename T> class ArrayRef;
class DemandedBits;
template <typename InstTy> class InterleaveGroup;
class IRBuilderBase;
class Loop;
class TargetTransformInfo;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
///
/// Note: isTriviallyVectorizable implies isTriviallyScalarizable.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identify if the intrinsic is trivially scalarizable.
/// This method returns true under the same predicates as
/// isTriviallyVectorizable.
///
/// Note: There are intrinsics where implementing vectorization for the
/// intrinsic is redundant, but we want to implement scalarization of the
/// vector. To avoid requiring that an intrinsic also implement vectorization,
/// we provide this separate function.
bool isTriviallyScalarizable(Intrinsic::ID ID, const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic has a scalar operand.
/// \p TTI is used to consider target-specific intrinsics; if target-specific
/// intrinsics are not to be considered, it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
                                        const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
/// \p TTI is used to consider target-specific intrinsics; if target-specific
/// intrinsics are not to be considered, it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
                                            const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx. \p TTI is used to
/// consider target-specific intrinsics; if target-specific intrinsics
/// are not to be considered, it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithStructReturnOverloadAtField(
    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);
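
// Illustrative sketch (assumptions: a scalarizer-style transform holding a
// CallInst `Call` to an intrinsic with ID `IID`, and a TargetTransformInfo
// pointer `TTI`; none of these names are defined in this header):
//
//   if (isTriviallyScalarizable(IID, TTI)) {
//     for (unsigned I = 0, E = Call->arg_size(); I != E; ++I) {
//       // Operands flagged by isVectorIntrinsicWithScalarOpAtArg stay
//       // scalar and must be passed through unchanged, not split per lane.
//       bool KeepScalar = isVectorIntrinsicWithScalarOpAtArg(IID, I, TTI);
//       ...
//     }
//   }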

/// Returns the intrinsic ID for the call.
/// For the given call instruction, this finds the mapped intrinsic and returns
/// its intrinsic ID; if no mapping is found, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get the splat value if the input is a splat vector, or return nullptr.
/// The value may be extracted from a splat constant vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands; returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);
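
// Illustrative sketch of narrowing a mask (the values mirror the example in
// the comment above; `Mask` and `Narrowed` are names assumed for the example):
//
//   SmallVector<int> Mask = {3, 2, 0, -1};   // mask over <4 x i32>
//   SmallVector<int> Narrowed;
//   narrowShuffleMaskElts(/*Scale=*/4, Mask, Narrowed);
//   // Narrowed == {12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1}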

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// A variant of the previous method which is specialized for Scale=2, and
/// treats -1 as undef and allows widening when a wider element is partially
/// undef in the narrow form of the mask. This transformation discards
/// information about which bytes in the original shuffle were undef.
bool widenShuffleMaskElts(ArrayRef<int> M, SmallVectorImpl<int> &NewMask);

/// Attempt to narrow/widen the \p Mask shuffle mask to the \p NumDstElts target
/// width. Internally this will call narrowShuffleMaskElts/widenShuffleMaskElts.
/// This will assert unless NumDstElts is a multiple of Mask.size() (or
/// vice versa). Returns false on failure, and ScaledMask will be in an
/// undefined state.
bool scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Repeatedly apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with the widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes a shuffle mask depending on the number of input and
/// output registers. The function does two main things: 1) it splits the
/// source/destination vectors into real registers; 2) it performs mask analysis
/// to identify which real registers are permuted. Then the function processes
/// the resulting register masks using the provided action items. If no input
/// register is defined, \p NoInputAction is used. If only 1 input register is
/// used, \p SingleInputAction is used; otherwise \p ManyInputsAction is used to
/// process 2 or more input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned, bool)>
        ManyInputsAction);

/// Compute the demanded elements mask of horizontal binary operations. A
/// horizontal operation combines two adjacent elements in a vector operand.
/// This function returns a mask for the elements that correspond to the first
/// operand of this horizontal combination. For example, for two vectors
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
/// result of this function to the left by 1.
///
/// \param VectorBitWidth the total bit width of the vector
/// \param DemandedElts the demanded elements mask for the operation
/// \param DemandedLHS the demanded elements mask for the left operand
/// \param DemandedRHS the demanded elements mask for the right operand
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
                                         const APInt &DemandedElts,
                                         APInt &DemandedLHS,
                                         APInt &DemandedRHS);
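
// Illustrative sketch (assumed context: analyzing a 128-bit horizontal
// operation over four 32-bit lanes; the names below are assumptions for the
// example):
//
//   APInt DemandedElts = APInt::getAllOnes(4); // all result lanes demanded
//   APInt DemandedLHS, DemandedRHS;
//   getHorizDemandedEltsForFirstOperand(/*VectorBitWidth=*/128, DemandedElts,
//                                       DemandedLHS, DemandedRHS);
//   // DemandedLHS/DemandedRHS now flag the first element of each adjacent
//   // pair (X1, X3 / Y1, Y3 in the example above); shifting the result left
//   // by one gives the mask for the second element of each pair.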

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However, InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these extensions and truncations when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction *, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                         const TargetTransformInfo *TTI = nullptr);

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group, MD_mmra].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave factor 3
/// and \p VF 4, that has only its first member present, is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);
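
// Illustrative sketch (assumed context: a vectorizer that already formed an
// interleave group `Group` with gaps and has an IRBuilder `Builder`; these
// names are not defined in this header):
//
//   // For a group with factor 3 and VF 4 where only member 0 is present,
//   // this materializes the constant <1,0,0,1,0,0,1,0,0,1,0,0>, which can be
//   // used to mask out the gap lanes of a masked load/store.
//   Constant *GapMask = createBitMaskForGaps(Builder, /*VF=*/4, *Group);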

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose interleave factor equals
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and begin
/// at \p Start. The mask contains \p NumInts integers and is padded with \p
/// NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into a
/// single large vector. The number of vectors should be greater than one, and
/// their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
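
// Illustrative sketch (assumed context: an IRBuilder `Builder` and a value
// `Wide` holding two interleaved sequences of four i32 elements; names are
// assumptions for the example):
//
//   // Extract the even and odd elements, i.e. deinterleave with factor 2.
//   Value *Even = Builder.CreateShuffleVector(
//       Wide, createStrideMask(/*Start=*/0, /*Stride=*/2, /*VF=*/4));
//   Value *Odd = Builder.CreateShuffleVector(
//       Wide, createStrideMask(/*Start=*/1, /*Stride=*/2, /*VF=*/4));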

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if any of the elements of this
/// predicate mask are known to be true or undef. That is, return true if at
/// least one lane can be assumed active.
bool maskContainsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);

/// A group of interleaved loads/stores that share the same stride and are
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is relative to the leader and it could be
  /// negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.contains(Key))
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index.
  ///
  /// \returns nullptr if the group contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a la
  /// addNewMetadata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilogue is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore can't be a reversed access,
    // because such groups get invalidated (TODO).
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32             // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
};
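
// Illustrative sketch (assumed context: member accesses `Load0`, `Load1`,
// `Load2` of a factor-4 strided access discovered by an analysis; the names
// are assumptions for the example):
//
//   InterleaveGroup<Instruction> Group(Load0, /*Stride=*/4, Align(4));
//   Group.insertMember(Load1, /*Index=*/1, Align(4));
//   Group.insertMember(Load2, /*Index=*/2, Align(4));
//   // The last member (index 3) is a gap, so the group may read past the
//   // final scalar element and needs a scalar epilogue iteration.
//   bool NeedsEpilogue = Group.requiresScalarEpilogue(); // true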

/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do analysis as the vectorization
/// on interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { invalidateGroups(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using \p Strides.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);

  /// Invalidate groups, e.g., in case all blocks in the loop will be
  /// predicated, contrary to the original assumption. Although we currently
  /// prevent group formation for predicated accesses, we may be able to relax
  /// this limitation in the future once we handle more complicated blocks.
  /// Returns true if any groups were invalidated.
  bool invalidateGroups() {
    if (InterleaveGroups.empty()) {
      assert(
          !RequiresScalarEpilogue &&
          "RequiresScalarEpilog should not be set without interleave groups");
      return false;
    }

    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
    return true;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.contains(Instr);
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if \p Instr doesn't belong to any group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    return InterleaveGroupMap.lookup(Instr);
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();

  /// Returns true if we have any interleave groups.
  bool hasGroups() const { return !InterleaveGroups.empty(); }
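
  // Illustrative usage sketch (assumed context: a vectorizer that already has
  // PSE, TheLoop, DT, LI, and LAI for a vectorizable loop; the names are
  // assumptions for the example):
  //
  //   InterleavedAccessInfo IAI(PSE, TheLoop, DT, LI, LAI);
  //   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
  //   if (auto *Group = IAI.getInterleaveGroup(MemberLoadOrStore))
  //     // Generate one wide access for the whole group, inserted at
  //     // Group->getInsertPos().
  //     ...;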

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;

  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Align.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    InterleaveGroups.erase(Group);
    releaseGroupWithoutRemovingFromSet(Group);
  }

  /// Do everything necessary to release the group, apart from removing it from
  /// the InterleaveGroups set.
  void releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const DenseMap<Value *, const SCEV *> &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering is
  /// not necessary or is prevented because \p A and \p B may be dependent.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;
    const auto &DepChecker = LAI->getDepChecker();
    auto *Deps = DepChecker.getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(DepChecker)].insert(
          Dep.getDestination(DepChecker));
  }
};

} // llvm namespace

#endif