xref: /llvm-project/llvm/include/llvm/Analysis/VectorUtils.h (revision bab7920fd7ea822543b8f1aa8037d489eea2cb73)
1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines some vectorizer utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
14 #define LLVM_ANALYSIS_VECTORUTILS_H
15 
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/IR/Module.h"
20 #include "llvm/IR/VFABIDemangler.h"
21 #include "llvm/IR/VectorTypeUtils.h"
22 #include "llvm/Support/CheckedArithmetic.h"
23 
24 namespace llvm {
25 class TargetLibraryInfo;
26 
27 /// The Vector Function Database.
28 ///
29 /// Helper class used to find the vector functions associated with a
30 /// scalar CallInst.
31 class VFDatabase {
32   /// The Module of the CallInst CI.
33   const Module *M;
34   /// The CallInst instance being queried for scalar to vector mappings.
35   const CallInst &CI;
36   /// List of vector function descriptors associated with the call
37   /// instruction.
38   const SmallVector<VFInfo, 8> ScalarToVectorMappings;
39 
40   /// Retrieve the scalar-to-vector mappings associated with the rules of
41   /// the Vector Function ABI.
42   static void getVFABIMappings(const CallInst &CI,
43                                SmallVectorImpl<VFInfo> &Mappings) {
44     if (!CI.getCalledFunction())
45       return;
46 
47     const StringRef ScalarName = CI.getCalledFunction()->getName();
48 
49     SmallVector<std::string, 8> ListOfStrings;
50     // The check for the vector-function-abi-variant attribute is done when
51     // retrieving the vector variant names here.
52     VFABI::getVectorVariantNames(CI, ListOfStrings);
53     if (ListOfStrings.empty())
54       return;
55     for (const auto &MangledName : ListOfStrings) {
56       const std::optional<VFInfo> Shape =
57           VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
58       // A match is found via scalar and vector names, and also by
59       // ensuring that the variant described in the attribute has a
60       // corresponding definition or declaration of the vector
61       // function in the Module M.
62       if (Shape && (Shape->ScalarName == ScalarName)) {
63         assert(CI.getModule()->getFunction(Shape->VectorName) &&
64                "Vector function is missing.");
65         Mappings.push_back(*Shape);
66       }
67     }
68   }
69 
70 public:
71   /// Retrieve all the VFInfo instances associated with the CallInst CI.
72   static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
73     SmallVector<VFInfo, 8> Ret;
74 
75     // Get mappings from the Vector Function ABI variants.
76     getVFABIMappings(CI, Ret);
77 
78     // Other non-VFABI variants should be retrieved here.
79 
80     return Ret;
81   }
82 
83   static bool hasMaskedVariant(const CallInst &CI,
84                                std::optional<ElementCount> VF = std::nullopt) {
85     // Check whether we have at least one masked vector version of a scalar
86     // function. If no VF is specified then we check for any masked variant,
87     // otherwise we look for one that matches the supplied VF.
88     auto Mappings = VFDatabase::getMappings(CI);
89     for (VFInfo Info : Mappings)
90       if (!VF || Info.Shape.VF == *VF)
91         if (Info.isMasked())
92           return true;
93 
94     return false;
95   }
96 
97   /// Constructor, requires a CallInst instance.
98   VFDatabase(CallInst &CI)
99       : M(CI.getModule()), CI(CI),
100         ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
101 
102   /// \defgroup VFDatabase query interface.
103   ///
104   /// @{
105   /// Retrieve the Function with VFShape \p Shape.
106   Function *getVectorizedFunction(const VFShape &Shape) const {
107     if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
108       return CI.getCalledFunction();
109 
110     for (const auto &Info : ScalarToVectorMappings)
111       if (Info.Shape == Shape)
112         return M->getFunction(Info.VectorName);
113 
114     return nullptr;
115   }
116   /// @}
117 };
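
// Example usage (editor's sketch, not part of the upstream header): querying
// the database for a 4-lane, unmasked variant of a call. `CI` stands for a
// `CallInst *` taken from the surrounding pass, and the VFShape::get overload
// taking a FunctionType is assumed here.
//
//   VFDatabase DB(*CI);
//   VFShape Shape = VFShape::get(CI->getFunctionType(),
//                                ElementCount::getFixed(4),
//                                /*HasGlobalPred=*/false);
//   if (Function *VecF = DB.getVectorizedFunction(Shape)) {
//     // A matching vector variant exists in the module; CI may be widened
//     // into a call of VecF with 4 lanes.
//   }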
118 
119 template <typename T> class ArrayRef;
120 class DemandedBits;
121 template <typename InstTy> class InterleaveGroup;
122 class IRBuilderBase;
123 class Loop;
124 class TargetTransformInfo;
125 class Value;
126 
127 namespace Intrinsic {
128 typedef unsigned ID;
129 }
130 
131 /// Identify if the intrinsic is trivially vectorizable.
132 /// This method returns true if the intrinsic's argument types are all scalars
133 /// for the scalar form of the intrinsic and all vectors (or scalars handled by
134 /// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
135 ///
136 /// Note: isTriviallyVectorizable implies isTriviallyScalarizable.
137 bool isTriviallyVectorizable(Intrinsic::ID ID);
138 
139 /// Identify if the intrinsic is trivially scalarizable.
140 /// This method returns true based on the same predicates as
141 /// isTriviallyVectorizable.
142 ///
143 /// Note: There are intrinsics for which implementing vectorization is
144 /// redundant, but we still want to implement scalarization of the
145 /// vector. To avoid requiring that such an intrinsic also implement
146 /// vectorization, we provide this separate function.
147 bool isTriviallyScalarizable(Intrinsic::ID ID, const TargetTransformInfo *TTI);
148 
149 /// Identifies if the vector form of the intrinsic has a scalar operand.
150 /// \p TTI is used to consider target-specific intrinsics; if target-specific
151 /// intrinsics should not be considered, it is appropriate to pass in nullptr.
152 bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
153                                         const TargetTransformInfo *TTI);
154 
155 /// Identifies if the vector form of the intrinsic is overloaded on the type of
156 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
157 /// \p TTI is used to consider target-specific intrinsics; if target-specific
158 /// intrinsics should not be considered, it is appropriate to pass in nullptr.
159 bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
160                                             const TargetTransformInfo *TTI);
161 
162 /// Identifies if the vector form of the intrinsic that returns a struct is
163 /// overloaded at the struct element index \p RetIdx. \p TTI is used to
164 /// consider target-specific intrinsics; if target-specific intrinsics
165 /// should not be considered, it is appropriate to pass in nullptr.
166 bool isVectorIntrinsicWithStructReturnOverloadAtField(
167     Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);
168 
169 /// Returns the intrinsic ID for the call.
170 /// For the given call instruction, this finds the corresponding intrinsic and
171 /// returns its intrinsic ID; if none is found, it returns not_intrinsic.
172 Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
173                                           const TargetLibraryInfo *TLI);
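
// Example (editor's sketch): a typical way these predicates are combined when
// deciding whether a call can be widened. `CI` and `TLI` are assumed to come
// from the caller's context.
//
//   Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
//   if (ID != Intrinsic::not_intrinsic && isTriviallyVectorizable(ID)) {
//     // The call maps to an intrinsic whose vector form simply takes vector
//     // operands, so it can be widened lane-wise.
//   }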
174 
175 /// Given a vector and an element number, see if the scalar value is
176 /// already available as a register, for example if it was inserted and then
177 /// extracted from the vector.
178 Value *findScalarElement(Value *V, unsigned EltNo);
179 
180 /// If all non-negative \p Mask elements are the same value, return that value.
181 /// If all elements are negative (undefined) or \p Mask contains different
182 /// non-negative values, return -1.
183 int getSplatIndex(ArrayRef<int> Mask);
184 
185 /// Get splat value if the input is a splat vector or return nullptr.
186 /// The value may be extracted from a splat constants vector or from
187 /// a sequence of instructions that broadcast a single value into a vector.
188 Value *getSplatValue(const Value *V);
189 
190 /// Return true if each element of the vector value \p V is poisoned or equal to
191 /// every other non-poisoned element. If an index element is specified, either
192 /// every element of the vector is poisoned or the element at that index is not
193 /// poisoned and equal to every other non-poisoned element.
194 /// This may be more powerful than the related getSplatValue() because it is
195 /// not limited by finding a scalar source value to a splatted vector.
196 bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
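
// Example (editor's sketch): how the two splat helpers differ for a value `V`
// of vector type.
//
//   if (Value *Scalar = getSplatValue(V)) {
//     // Every lane of V is this scalar source value.
//   } else if (isSplatValue(V)) {
//     // All non-poison lanes are provably equal, but no single scalar source
//     // value could be extracted.
//   }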
197 
198 /// Transform a shuffle mask's output demanded element mask into demanded
199 /// element masks for the two operands; returns false if the mask isn't valid.
200 /// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
201 /// \p AllowUndefElts permits "-1" indices to be treated as undef.
202 bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
203                             const APInt &DemandedElts, APInt &DemandedLHS,
204                             APInt &DemandedRHS, bool AllowUndefElts = false);
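
// Example (editor's sketch): for a shuffle of two <4 x i32> sources with mask
// <0, 5, 2, 7>, demanding only result lanes 0 and 1 demands element 0 of the
// first operand and element 1 (mask index 5) of the second operand.
//
//   APInt DemandedElts = APInt::getLowBitsSet(4, 2); // result lanes 0 and 1
//   APInt DemandedLHS, DemandedRHS;
//   bool Ok = getShuffleDemandedElts(/*SrcWidth=*/4, {0, 5, 2, 7},
//                                    DemandedElts, DemandedLHS, DemandedRHS);
//   // Ok == true, DemandedLHS == 0b0001, DemandedRHS == 0b0010.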
205 
206 /// Replace each shuffle mask index with the scaled sequential indices for an
207 /// equivalent mask of narrowed elements. Mask elements that are less than 0
208 /// (sentinel values) are repeated in the output mask.
209 ///
210 /// Example with Scale = 4:
211 ///   <4 x i32> <3, 2, 0, -1> -->
212 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
213 ///
214 /// This is the reverse process of widening shuffle mask elements, but it always
215 /// succeeds because the indexes can always be multiplied (scaled up) to map to
216 /// narrower vector elements.
217 void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
218                            SmallVectorImpl<int> &ScaledMask);
219 
220 /// Try to transform a shuffle mask by replacing elements with the scaled index
221 /// for an equivalent mask of widened elements. If all mask elements that would
222 /// map to a wider element of the new mask are the same negative number
223 /// (sentinel value), that element of the new mask is the same value. If any
224 /// element in a given slice is negative and some other element in that slice is
225 /// not the same value, return false (partial matches with sentinel values are
226 /// not allowed).
227 ///
228 /// Example with Scale = 4:
229 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
230 ///   <4 x i32> <3, 2, 0, -1>
231 ///
232 /// This is the reverse process of narrowing shuffle mask elements if it
233 /// succeeds. This transform is not always possible because indexes may not
234 /// divide evenly (scale down) to map to wider vector elements.
235 bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
236                           SmallVectorImpl<int> &ScaledMask);
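
// Example (editor's sketch): round-tripping the Scale = 4 masks from the
// comments above.
//
//   SmallVector<int> NarrowMask;
//   narrowShuffleMaskElts(4, {3, 2, 0, -1}, NarrowMask);
//   // NarrowMask == <12,13,14,15, 8,9,10,11, 0,1,2,3, -1,-1,-1,-1>
//
//   SmallVector<int> WideMask;
//   bool Ok = widenShuffleMaskElts(4, NarrowMask, WideMask);
//   // Ok == true and WideMask == <3, 2, 0, -1>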
237 
238 /// A variant of the previous method which is specialized for Scale=2, and
239 /// treats -1 as undef and allows widening when a wider element is partially
240 /// undef in the narrow form of the mask.  This transformation discards
241 /// information about which bytes in the original shuffle were undef.
242 bool widenShuffleMaskElts(ArrayRef<int> M, SmallVectorImpl<int> &NewMask);
243 
244 /// Attempt to narrow/widen the \p Mask shuffle mask to the \p NumDstElts target
245 /// width. Internally this will call narrowShuffleMaskElts/widenShuffleMaskElts.
246 /// This will assert unless NumDstElts is a multiple of Mask.size (or
247 /// vice-versa). Returns false on failure, and ScaledMask will be in an
248 /// undefined state.
249 bool scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
250                           SmallVectorImpl<int> &ScaledMask);
251 
252 /// Repeatedly apply `widenShuffleMaskElts()` for as long as it succeeds,
253 /// to get the shuffle mask with the widest possible elements.
254 void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
255                                   SmallVectorImpl<int> &ScaledMask);
256 
257 /// Splits and processes a shuffle mask depending on the number of input and
258 /// output registers. The function does two main things: 1) splits the
259 /// source/destination vectors into real registers; 2) analyzes the mask to
260 /// identify which real registers are permuted. Then the function processes the
261 /// resulting register masks using the provided action callbacks. If no input
262 /// register is defined, \p NoInputAction is used. If only one input register is
263 /// used, \p SingleInputAction is used; otherwise \p ManyInputsAction is used to
264 /// process two or more input registers and their masks.
265 /// \param Mask Original shuffle mask.
266 /// \param NumOfSrcRegs Number of source registers.
267 /// \param NumOfDestRegs Number of destination registers.
268 /// \param NumOfUsedRegs Number of actually used destination registers.
269 void processShuffleMasks(
270     ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
271     unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
272     function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
273     function_ref<void(ArrayRef<int>, unsigned, unsigned, bool)>
274         ManyInputsAction);
275 
276 /// Compute the demanded elements mask of horizontal binary operations. A
277 /// horizontal operation combines two adjacent elements in a vector operand.
278 /// This function returns a mask for the elements that correspond to the first
279 /// operand of this horizontal combination. For example, for two vectors
280 /// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
281 /// elements X1, X3, Y1, and Y3. To get the mask for the second operands,
282 /// simply shift the result of this function to the left by 1.
283 ///
284 /// \param VectorBitWidth the total bit width of the vector
285 /// \param DemandedElts   the demanded elements mask for the operation
286 /// \param DemandedLHS    the demanded elements mask for the left operand
287 /// \param DemandedRHS    the demanded elements mask for the right operand
288 void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
289                                          const APInt &DemandedElts,
290                                          APInt &DemandedLHS,
291                                          APInt &DemandedRHS);
292 
293 /// Compute a map of integer instructions to their minimum legal type
294 /// size.
295 ///
296 /// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
297 /// type (e.g. i32) whenever arithmetic is performed on them.
298 ///
299 /// For targets with native i8 or i16 operations, usually InstCombine can shrink
300 /// the arithmetic type down again. However InstCombine refuses to create
301 /// illegal types, so for targets without i8 or i16 registers, the lengthening
302 /// and shrinking remains.
303 ///
304 /// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
305 /// their scalar equivalents do not, so during vectorization it is important to
306 /// remove these lengthens and truncates when deciding the profitability of
307 /// vectorization.
308 ///
309 /// This function analyzes the given range of instructions and determines the
310 /// minimum type size each can be converted to. It attempts to remove or
311 /// minimize type size changes across each def-use chain, so for example in the
312 /// following code:
313 ///
314 ///   %1 = load i8, i8*
315 ///   %2 = add i8 %1, 2
316 ///   %3 = load i16, i16*
317 ///   %4 = zext i8 %2 to i32
318 ///   %5 = zext i16 %3 to i32
319 ///   %6 = add i32 %4, %5
320 ///   %7 = trunc i32 %6 to i16
321 ///
322 /// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
323 /// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
324 ///
325 /// If the optional TargetTransformInfo is provided, this function tries harder
326 /// to do less work by only looking at illegal types.
327 MapVector<Instruction*, uint64_t>
328 computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
329                          DemandedBits &DB,
330                          const TargetTransformInfo *TTI=nullptr);
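
// Example (editor's sketch): how a vectorizer might consume the result for a
// loop `L`, given a DemandedBits analysis `DB` and a TargetTransformInfo
// `TTI` from the surrounding pass.
//
//   MapVector<Instruction *, uint64_t> MinBWs =
//       computeMinimumValueSizes(L->getBlocks(), DB, &TTI);
//   for (const auto &[I, Bits] : MinBWs) {
//     // `I` can be evaluated in an integer type of width `Bits` without
//     // changing the result, which allows costing narrower vector types.
//   }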
331 
332 /// Compute the union of two access-group lists.
333 ///
334 /// If the list contains just one access group, it is returned directly. If the
335 /// list is empty, returns nullptr.
336 MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);
337 
338 /// Compute the access-group list of access groups that @p Inst1 and @p Inst2
339 /// are both in. If either instruction does not access memory at all, it is
340 /// considered to be in every list.
341 ///
342 /// If the list contains just one access group, it is returned directly. If the
343 /// list is empty, returns nullptr.
344 MDNode *intersectAccessGroups(const Instruction *Inst1,
345                               const Instruction *Inst2);
346 
347 /// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
348 /// MD_nontemporal, MD_access_group, MD_mmra].
349 /// For K in Kinds, we get the MDNode for K from each of the
350 /// elements of VL, compute their "intersection" (i.e., the most generic
351 /// metadata value that covers all of the individual values), and set I's
352 /// metadata for K equal to the intersection value.
353 ///
354 /// This function always sets a (possibly null) value for each K in Kinds.
355 Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
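
// Example (editor's sketch): after widening a group of scalar loads into one
// vector load `WideLoad`, a vectorizer would typically call
//
//   propagateMetadata(WideLoad, ScalarLoads);
//
// where `ScalarLoads` holds the original scalar instructions (as Value *), so
// that TBAA, alias-scope, access-group and similar metadata survive on the
// widened access.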
356 
357 /// Create a mask that filters the members of an interleave group where there
358 /// are gaps.
359 ///
360 /// For example, the mask for \p Group with interleave-factor 3
361 /// and \p VF 4, that has only its first member present is:
362 ///
363 ///   <1,0,0,1,0,0,1,0,0,1,0,0>
364 ///
365 /// Note: The result is a mask of 0's and 1's, as opposed to the other
366 /// create[*]Mask() utilities which create a shuffle mask (mask that
367 /// consists of indices).
368 Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
369                                const InterleaveGroup<Instruction> &Group);
370 
371 /// Create a mask with replicated elements.
372 ///
373 /// This function creates a shuffle mask for replicating each of the \p VF
374 /// elements in a vector \p ReplicationFactor times. It can be used to
375 /// transform a mask of \p VF elements into a mask of
376 /// \p VF * \p ReplicationFactor elements used by a predicated
377 /// interleaved group of loads/stores whose interleave factor ==
378 /// \p ReplicationFactor.
379 ///
380 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
381 ///
382 ///   <0,0,0,1,1,1,2,2,2,3,3,3>
383 llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
384                                                 unsigned VF);
385 
386 /// Create an interleave shuffle mask.
387 ///
388 /// This function creates a shuffle mask for interleaving \p NumVecs vectors of
389 /// vectorization factor \p VF into a single wide vector. The mask is of the
390 /// form:
391 ///
392 ///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
393 ///
394 /// For example, the mask for VF = 4 and NumVecs = 2 is:
395 ///
396 ///   <0, 4, 1, 5, 2, 6, 3, 7>.
397 llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);
398 
399 /// Create a stride shuffle mask.
400 ///
401 /// This function creates a shuffle mask whose elements begin at \p Start and
402 /// are incremented by \p Stride. The mask can be used to deinterleave an
403 /// interleaved vector into separate vectors of vectorization factor \p VF. The
404 /// mask is of the form:
405 ///
406 ///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
407 ///
408 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
409 ///
410 ///   <0, 2, 4, 6>
411 llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
412                                             unsigned VF);
413 
414 /// Create a sequential shuffle mask.
415 ///
416 /// This function creates a shuffle mask whose elements are sequential and begin
417 /// at \p Start.  The mask contains \p NumInts integers and is padded with \p
418 /// NumUndefs undef values. The mask is of the form:
419 ///
420 ///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
421 ///
422 /// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
423 ///
424 ///   <0, 1, 2, 3, undef, undef, undef, undef>
425 llvm::SmallVector<int, 16>
426 createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);
427 
428 /// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
429 /// mask assuming both operands are identical. This assumes that the unary
430 /// shuffle will use elements from operand 0 (operand 1 will be unused).
431 llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
432                                            unsigned NumElts);
433 
434 /// Concatenate a list of vectors.
435 ///
436 /// This function generates code that concatenates the vectors in \p Vecs into a
437 /// single large vector. The number of vectors should be greater than one, and
438 /// their element types should be the same. The number of elements in the
439 /// vectors should also be the same; however, if the last vector has fewer
440 /// elements, it will be padded with undefs.
441 Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
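
// Example (editor's sketch): interleaving two <4 x i32> values `V0` and `V1`
// with an IRBuilder `B`, and recovering `V0` again. The values and builder
// are assumed to come from the caller's context.
//
//   Value *Wide = concatenateVectors(B, {V0, V1});            // <8 x i32>
//   Value *Interleaved = B.CreateShuffleVector(
//       Wide, createInterleaveMask(/*VF=*/4, /*NumVecs=*/2));
//   // Lane order: V0[0], V1[0], V0[1], V1[1], V0[2], V1[2], V0[3], V1[3]
//
//   Value *Evens = B.CreateShuffleVector(
//       Interleaved, createStrideMask(/*Start=*/0, /*Stride=*/2, /*VF=*/4));
//   // Evens == <V0[0], V0[1], V0[2], V0[3]>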
442 
443 /// Given a mask vector of i1, return true if all of the elements of this
444 /// predicate mask are known to be false or undef.  That is, return true if all
445 /// lanes can be assumed inactive.
446 bool maskIsAllZeroOrUndef(Value *Mask);
447 
448 /// Given a mask vector of i1, return true if all of the elements of this
449 /// predicate mask are known to be true or undef.  That is, return true if all
450 /// lanes can be assumed active.
451 bool maskIsAllOneOrUndef(Value *Mask);
452 
453 /// Given a mask vector of i1, return true if any of the elements of this
454 /// predicate mask are known to be true or undef.  That is, return true if at
455 /// least one lane can be assumed active.
456 bool maskContainsAllOneOrUndef(Value *Mask);
457 
458 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
459 /// for each lane which may be active.
460 APInt possiblyDemandedEltsInMask(Value *Mask);
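
// Example (editor's sketch): using the mask predicates when simplifying a
// masked memory intrinsic whose mask operand is `MaskOp`.
//
//   if (maskIsAllZeroOrUndef(MaskOp)) {
//     // No lane can be active; the whole operation may be dropped.
//   } else if (maskIsAllOneOrUndef(MaskOp)) {
//     // Every lane is active; the operation can become an unmasked one.
//   } else {
//     APInt ActiveLanes = possiblyDemandedEltsInMask(MaskOp);
//     // Only the set bits of ActiveLanes may correspond to active lanes.
//   }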
461 
462 /// A group of interleaved loads/stores sharing the same stride and
463 /// close to each other.
464 ///
465 /// Each member in this group has an index starting from 0, and the largest
466 /// index should be less than the interleave factor, which is equal to the
467 /// absolute value of the access's stride.
468 ///
469 /// E.g. An interleaved load group of factor 4:
470 ///        for (unsigned i = 0; i < 1024; i+=4) {
471 ///          a = A[i];                           // Member of index 0
472 ///          b = A[i+1];                         // Member of index 1
473 ///          d = A[i+3];                         // Member of index 3
474 ///          ...
475 ///        }
476 ///
477 ///      An interleaved store group of factor 4:
478 ///        for (unsigned i = 0; i < 1024; i+=4) {
479 ///          ...
480 ///          A[i]   = a;                         // Member of index 0
481 ///          A[i+1] = b;                         // Member of index 1
482 ///          A[i+2] = c;                         // Member of index 2
483 ///          A[i+3] = d;                         // Member of index 3
484 ///        }
485 ///
486 /// Note: the interleaved load group could have gaps (missing members), but
487 /// the interleaved store group doesn't allow gaps.
488 template <typename InstTy> class InterleaveGroup {
489 public:
490   InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
491       : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
492         InsertPos(nullptr) {}
493 
494   InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
495       : Alignment(Alignment), InsertPos(Instr) {
496     Factor = std::abs(Stride);
497     assert(Factor > 1 && "Invalid interleave factor");
498 
499     Reverse = Stride < 0;
500     Members[0] = Instr;
501   }
502 
503   bool isReverse() const { return Reverse; }
504   uint32_t getFactor() const { return Factor; }
505   Align getAlign() const { return Alignment; }
506   uint32_t getNumMembers() const { return Members.size(); }
507 
508   /// Try to insert a new member \p Instr with index \p Index and
509   /// alignment \p NewAlign. The index is relative to the leader and can be
510   /// negative if the new member becomes the new leader.
511   ///
512   /// \returns false if the instruction doesn't belong to the group.
513   bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
514     // Make sure the key fits in an int32_t.
515     std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
516     if (!MaybeKey)
517       return false;
518     int32_t Key = *MaybeKey;
519 
520     // Skip if the key is used for either the tombstone or empty special values.
521     if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
522         DenseMapInfo<int32_t>::getEmptyKey() == Key)
523       return false;
524 
525     // Skip if there is already a member with the same index.
526     if (Members.contains(Key))
527       return false;
528 
529     if (Key > LargestKey) {
530       // The largest index is always less than the interleave factor.
531       if (Index >= static_cast<int32_t>(Factor))
532         return false;
533 
534       LargestKey = Key;
535     } else if (Key < SmallestKey) {
536 
537       // Make sure the largest index fits in an int32_t.
538       std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
539       if (!MaybeLargestIndex)
540         return false;
541 
542       // The largest index is always less than the interleave factor.
543       if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
544         return false;
545 
546       SmallestKey = Key;
547     }
548 
549     // It's always safe to select the minimum alignment.
550     Alignment = std::min(Alignment, NewAlign);
551     Members[Key] = Instr;
552     return true;
553   }
554 
555   /// Get the member with the given index \p Index
556   ///
557   /// \returns nullptr if the group contains no such member.
558   InstTy *getMember(uint32_t Index) const {
559     int32_t Key = SmallestKey + Index;
560     return Members.lookup(Key);
561   }
562 
563   /// Get the index for the given member. Unlike the key in the member
564   /// map, the index starts from 0.
565   uint32_t getIndex(const InstTy *Instr) const {
566     for (auto I : Members) {
567       if (I.second == Instr)
568         return I.first - SmallestKey;
569     }
570 
571     llvm_unreachable("InterleaveGroup contains no such member");
572   }
573 
574   InstTy *getInsertPos() const { return InsertPos; }
575   void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
576 
577   /// Add metadata (e.g. alias info) from the instructions in this group to \p
578   /// NewInst.
579   ///
580   /// FIXME: this function currently does not add noalias metadata a la
581   /// addNewMedata. To do that we need to compute the intersection of the
582   /// noalias info from all members.
583   void addMetadata(InstTy *NewInst) const;
584 
585   /// Returns true if this Group requires a scalar iteration to handle gaps.
586   bool requiresScalarEpilogue() const {
587     // If the last member of the Group exists, then a scalar epilog is not
588     // needed for this group.
589     if (getMember(getFactor() - 1))
590       return false;
591 
592     // We have a group with gaps. It therefore can't be a reversed access,
593     // because such groups get invalidated (TODO).
594     assert(!isReverse() && "Group should have been invalidated");
595 
596     // This is a group of loads, with gaps, and without a last member.
597     return true;
598   }
599 
600 private:
601   uint32_t Factor; // Interleave Factor.
602   bool Reverse;
603   Align Alignment;
604   DenseMap<int32_t, InstTy *> Members;
605   int32_t SmallestKey = 0;
606   int32_t LargestKey = 0;
607 
608   // To avoid breaking dependences, vectorized instructions of an interleave
609   // group should be inserted at either the first load or the last store in
610   // program order.
611   //
612   // E.g. %even = load i32             // Insert Position
613   //      %add = add i32 %even         // Use of %even
614   //      %odd = load i32
615   //
616   //      store i32 %even
617   //      %odd = add i32               // Def of %odd
618   //      store i32 %odd               // Insert Position
619   InstTy *InsertPos;
620 };
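
// Example (editor's sketch): how the factor-4 load group from the class
// comment would be represented. `A0`, `A1` and `A3` stand for the loads of
// A[i], A[i+1] and A[i+3].
//
//   InterleaveGroup<Instruction> G(A0, /*Stride=*/4, Align(4));
//   G.insertMember(A1, /*Index=*/1, Align(4));
//   G.insertMember(A3, /*Index=*/3, Align(4));
//   // G.getFactor() == 4, G.getMember(2) == nullptr (the gap), and
//   // G.requiresScalarEpilogue() == false because the last member (index 3)
//   // is present.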
621 
622 /// Drive the analysis of interleaved memory accesses in the loop.
623 ///
624 /// Use this class to analyze interleaved accesses only when we can vectorize
625 /// a loop. Otherwise it is meaningless to do the analysis, as vectorization
626 /// of interleaved accesses is unsafe.
627 ///
628 /// The analysis collects interleave groups and records the relationships
629 /// between the member and the group in a map.
630 class InterleavedAccessInfo {
631 public:
632   InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
633                         DominatorTree *DT, LoopInfo *LI,
634                         const LoopAccessInfo *LAI)
635       : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
636 
637   ~InterleavedAccessInfo() { invalidateGroups(); }
638 
639   /// Analyze the interleaved accesses and collect them in interleave
640   /// groups. Substitute symbolic strides using \p Strides.
641   /// Consider also predicated loads/stores in the analysis if
642   /// \p EnableMaskedInterleavedGroup is true.
643   void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
644 
645   /// Invalidate groups, e.g., in case all blocks in the loop will be predicated,
646   /// contrary to the original assumption. Although we currently prevent group
647   /// formation for predicated accesses, we may be able to relax this limitation
648   /// in the future once we handle more complicated blocks. Returns true if any
649   /// groups were invalidated.
650   bool invalidateGroups() {
651     if (InterleaveGroups.empty()) {
652       assert(
653           !RequiresScalarEpilogue &&
654           "RequiresScalarEpilog should not be set without interleave groups");
655       return false;
656     }
657 
658     InterleaveGroupMap.clear();
659     for (auto *Ptr : InterleaveGroups)
660       delete Ptr;
661     InterleaveGroups.clear();
662     RequiresScalarEpilogue = false;
663     return true;
664   }
665 
666   /// Check if \p Instr belongs to any interleave group.
667   bool isInterleaved(Instruction *Instr) const {
668     return InterleaveGroupMap.contains(Instr);
669   }
670 
671   /// Get the interleave group that \p Instr belongs to.
672   ///
673   /// \returns nullptr if \p Instr does not belong to any interleave group.
674   InterleaveGroup<Instruction> *
675   getInterleaveGroup(const Instruction *Instr) const {
676     return InterleaveGroupMap.lookup(Instr);
677   }
678 
679   iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
680   getInterleaveGroups() {
681     return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
682   }
683 
684   /// Returns true if an interleaved group that may access memory
685   /// out-of-bounds requires a scalar epilogue iteration for correctness.
686   bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
687 
688   /// Invalidate groups that require a scalar epilogue (due to gaps). This can
689   /// happen when optimizing for size forbids a scalar epilogue, and the gap
690   /// cannot be filtered by masking the load/store.
691   void invalidateGroupsRequiringScalarEpilogue();
692 
693   /// Returns true if we have any interleave groups.
694   bool hasGroups() const { return !InterleaveGroups.empty(); }
695 
696 private:
697   /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
698   /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
699   /// The interleaved access analysis can also add new predicates (for example
700   /// by versioning strides of pointers).
701   PredicatedScalarEvolution &PSE;
702 
703   Loop *TheLoop;
704   DominatorTree *DT;
705   LoopInfo *LI;
706   const LoopAccessInfo *LAI;
707 
708   /// True if the loop may contain non-reversed interleaved groups with
709   /// out-of-bounds accesses. We ensure we don't speculatively access memory
710   /// out-of-bounds by executing at least one scalar epilogue iteration.
711   bool RequiresScalarEpilogue = false;
712 
713   /// Holds the relationships between the members and the interleave group.
714   DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
715 
716   SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
717 
718   /// Holds dependences among the memory accesses in the loop. It maps a source
719   /// access to a set of dependent sink accesses.
720   DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
721 
722   /// The descriptor for a strided memory access.
723   struct StrideDescriptor {
724     StrideDescriptor() = default;
725     StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
726                      Align Alignment)
727         : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
728 
729     // The access's stride. It is negative for a reverse access.
730     int64_t Stride = 0;
731 
732     // The scalar expression of this access.
733     const SCEV *Scev = nullptr;
734 
735     // The size of the memory object.
736     uint64_t Size = 0;
737 
738     // The alignment of this access.
739     Align Alignment;
740   };
741 
742   /// A type for holding instructions and their stride descriptors.
743   using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
744 
745   /// Create a new interleave group with the given instruction \p Instr,
746   /// stride \p Stride and alignment \p Alignment.
747   ///
748   /// \returns the newly created interleave group.
749   InterleaveGroup<Instruction> *
750   createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
751     assert(!InterleaveGroupMap.count(Instr) &&
752            "Already in an interleaved access group");
753     InterleaveGroupMap[Instr] =
754         new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
755     InterleaveGroups.insert(InterleaveGroupMap[Instr]);
756     return InterleaveGroupMap[Instr];
757   }
758 
759   /// Release the group and remove all the relationships.
760   void releaseGroup(InterleaveGroup<Instruction> *Group) {
761     InterleaveGroups.erase(Group);
762     releaseGroupWithoutRemovingFromSet(Group);
763   }
764 
765   /// Do everything necessary to release the group, apart from removing it from
766   /// the InterleaveGroups set.
767   void releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> *Group) {
768     for (unsigned i = 0; i < Group->getFactor(); i++)
769       if (Instruction *Member = Group->getMember(i))
770         InterleaveGroupMap.erase(Member);
771 
772     delete Group;
773   }
774 
775   /// Collect all the accesses with a constant stride in program order.
776   void collectConstStrideAccesses(
777       MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
778       const DenseMap<Value *, const SCEV *> &Strides);
779 
780   /// Returns true if \p Stride is allowed in an interleaved group.
781   static bool isStrided(int Stride);
782 
783   /// Returns true if \p BB is a predicated block.
784   bool isPredicated(BasicBlock *BB) const {
785     return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
786   }
787 
788   /// Returns true if LoopAccessInfo can be used for dependence queries.
789   bool areDependencesValid() const {
790     return LAI && LAI->getDepChecker().getDependences();
791   }
792 
793   /// Returns true if memory accesses \p A and \p B can be reordered, if
794   /// necessary, when constructing interleaved groups.
795   ///
796   /// \p A must precede \p B in program order. We return false if \p A and \p B
797   /// may be dependent and reordering them would therefore be unsafe.
798   bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
799                                                  StrideEntry *B) const {
800     // Code motion for interleaved accesses can potentially hoist strided loads
801     // and sink strided stores. The code below checks the legality of the
802     // following two conditions:
803     //
804     // 1. Potentially moving a strided load (B) before any store (A) that
805     //    precedes B, or
806     //
807     // 2. Potentially moving a strided store (A) after any load or store (B)
808     //    that A precedes.
809     //
810     // It's legal to reorder A and B if we know there isn't a dependence from A
811     // to B. Note that this determination is conservative since some
812     // dependences could potentially be reordered safely.
813 
814     // A is potentially the source of a dependence.
815     auto *Src = A->first;
816     auto SrcDes = A->second;
817 
818     // B is potentially the sink of a dependence.
819     auto *Sink = B->first;
820     auto SinkDes = B->second;
821 
822     // Code motion for interleaved accesses can't violate WAR dependences.
823     // Thus, reordering is legal if the source isn't a write.
824     if (!Src->mayWriteToMemory())
825       return true;
826 
827     // At least one of the accesses must be strided.
828     if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
829       return true;
830 
831     // If dependence information is not available from LoopAccessInfo,
832     // conservatively assume the instructions can't be reordered.
833     if (!areDependencesValid())
834       return false;
835 
836     // If we know there is a dependence from source to sink, assume the
837     // instructions can't be reordered. Otherwise, reordering is legal.
838     return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
839   }
840 
841   /// Collect the dependences from LoopAccessInfo.
842   ///
843   /// We process the dependences once during the interleaved access analysis to
844   /// enable constant-time dependence queries.
845   void collectDependences() {
846     if (!areDependencesValid())
847       return;
848     const auto &DepChecker = LAI->getDepChecker();
849     auto *Deps = DepChecker.getDependences();
850     for (auto Dep : *Deps)
851       Dependences[Dep.getSource(DepChecker)].insert(
852           Dep.getDestination(DepChecker));
853   }
854 };
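
// Example (editor's sketch): how a vectorizer typically drives the analysis.
// `PSE`, `L`, `DT`, `LI` and `LAI` are assumed to come from the surrounding
// pass, and `MemInst` is some load or store in the loop.
//
//   InterleavedAccessInfo IAI(PSE, L, DT, LI, LAI);
//   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
//   if (const auto *Group = IAI.getInterleaveGroup(MemInst)) {
//     // MemInst belongs to a group of Group->getFactor() accesses that can
//     // be widened into a single wide access plus shuffles.
//   }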
855 
856 } // llvm namespace
857 
858 #endif
859