//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate
// narrow types that are not supported by the target hardware, e.g. i4, using
// wider types, e.g. i8.
//
/// Currently, only power-of-two integer types are supported. These are
/// converted to wider integers that are at least 8 bits wide.
///
/// TODO: Support for non-powers-of-two.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <optional>

using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using VectorValue = TypedValue<VectorType>;
using MemRefValue = TypedValue<MemRefType>;

/// Returns a compressed mask for the emulated vector. For example, when
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
/// elements span two dest elements), this method compresses `vector<8xi1>`
/// into `vector<2xi1>`.
///
/// A compressed/output mask bit is set iff any bit in the corresponding
/// `numSrcElemsPerDest` range of the uncompressed/input mask is set. E.g., if
/// `numSrcElemsPerDest` equals 2 and `numFrontPadElems` equals 1, the
/// following mask:
///
///   %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded at the front with `numFrontPadElems` zeros, and zeros
/// will be appended at the back to make the number of elements a multiple of
/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
///
///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// and the method then returns the following new compressed mask:
///
///   %mask = [1, 1, 0, 0]
///
/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
/// `numSrcElemsPerDest`.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  auto numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;
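  // For example (illustrative), for the case documented above (numSrcElems =
  // 6, numSrcElemsPerDest = 2, numFrontPadElems = 1), this ceil-division
  // yields (1 + 6 + 2 - 1) / 2 = 4 destination mask elements.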

  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  // TODO: add support for `vector.splat`.
  // Finding the mask creation operation.
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    } else {
      // Bail out on unsupported mask-defining ops instead of looping forever.
      break;
    }
  }

  if (!maskOp ||
      !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // Computing the "compressed" mask. All the emulation logic (i.e. computing
  // new mask index) only happens on the last dimension of the vectors.
  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                OperandRange maskOperands = createMaskOp.getOperands();
                // The `vector.create_mask` op creates a mask arrangement
                // without any zeros at the front. Also, because
                // `numFrontPadElems` is strictly smaller than
                // `numSrcElemsPerDest`, the compressed mask generated by
                // padding the original mask by `numFrontPadElems` will not
                // have any zeros at the front either.
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
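                // For example (illustrative), with the documented case above
                // (numFrontPadElems = 1, numSrcElemsPerDest = 2), an original
                // create_mask bound of 2 becomes ceil((2 + 1) / 2) = 2 set
                // elements out of 4, matching the compressed mask shown there.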
                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>(
              [&](auto constantMaskOp) -> std::optional<Operation *> {
                // Take the shape of mask, compress its trailing dimension:
                SmallVector<int64_t> maskDimSizes(
                    constantMaskOp.getMaskDimSizes());
                int64_t &maskIndex = maskDimSizes.back();
                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
                                             numSrcElemsPerDest);
                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                               maskDimSizes);
              })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // TODO: Support multiple dimensions.
            if (maskShape.size() != 1)
              return std::nullopt;
            // Rearrange the original mask values to cover the whole potential
            // loading region. For example, in the case of using byte-size for
            // emulation, given the following mask:
            //
            // %mask = [0, 1, 0, 1, 0, 0]
            //
            // With a front offset of 1, the mask will be padded with 0s at the
            // front and back so that:
            // 1. It is aligned with the effective loading bits.
            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
            //    total coverage size is a multiple of bytes). The new mask
            //    will look like this before compressing:
            //
            // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            // Compressing by combining every `numSrcElemsPerDest` elements:
            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              }
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
          });

  if (!newMask)
    return failure();

  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}

/// Extracts a 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
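/// For example (illustrative), extracting 2 elements at offset 1 from a
/// `vector<4xi2>` emits roughly:
///   vector.extract_strided_slice %source
///     {offsets = [1], sizes = [2], strides = [1]}
///     : vector<4xi2> to vector<2xi2>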
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        Value source, int64_t frontOffset,
                                        int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
  assert(vectorType.getRank() == 1 && "expected 1-D source types");
  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // No extraction is needed if the subvector size matches the source size.
  if (vectorType.getNumElements() == subvecSize)
    return source;

  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});

  auto resultVectorType =
      VectorType::get({subvecSize}, vectorType.getElementType());
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
                                             offsets, sizes, strides)
      ->getResult(0);
}

/// Inserts a 1-D subvector into a 1-D vector by overwriting the elements
/// starting at `offset`. It is a wrapper function for emitting
/// `vector.insert_strided_slice`.
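/// For example (illustrative), inserting a `vector<2xi2>` into a `vector<4xi2>`
/// at offset 2 emits roughly:
///   vector.insert_strided_slice %src, %dest {offsets = [2], strides = [1]}
///     : vector<2xi2> into vector<4xi2>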
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected source and dest to be vector type");
  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
                                                       dest, offsets, strides);
}

/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
/// and size `numElementsToExtract`, and inserts it into the `dest` vector. This
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
/// use it when `offset` cannot be folded into a constant value.
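/// For example (illustrative), extracting 2 elements starting at a dynamic
/// %offset emits roughly:
///   %e0 = vector.extract %source[%offset] : i2 from vector<8xi2>
///   %d0 = vector.insert %e0, %dest[0] : i2 into vector<2xi2>
///   %i1 = arith.addi %offset, %c1 : index
///   %e1 = vector.extract %source[%i1] : i2 from vector<8xi2>
///   %d1 = vector.insert %e1, %d0[1] : i2 into vector<2xi2>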
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         Value source, Value dest,
                                         OpFoldResult offset,
                                         int64_t numElementsToExtract) {
  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
  for (int i = 0; i < numElementsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? offset.dyn_cast<Value>()
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp =
        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}

/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
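/// For example (illustrative), inserting 2 elements at a dynamic
/// %destOffsetVar emits roughly:
///   %e0 = vector.extract %source[0] : i2 from vector<2xi2>
///   %d0 = vector.insert %e0, %dest[%destOffsetVar] : i2 into vector<8xi2>
///   %i1 = arith.addi %destOffsetVar, %c1 : index
///   %e1 = vector.extract %source[1] : i2 from vector<2xi2>
///   %d1 = vector.insert %e1, %d0[%i1] : i2 into vector<8xi2>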
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        Value source, Value dest,
                                        OpFoldResult destOffsetVar,
                                        size_t length) {
  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
  assert(length > 0 && "length must be greater than 0");
  Value destOffsetVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
  for (size_t i = 0; i < length; ++i) {
    auto insertLoc = i == 0
                         ? destOffsetVal
                         : rewriter.create<arith::AddIOp>(
                               loc, rewriter.getIndexType(), destOffsetVal,
                               rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}

/// Returns the op sequence for an emulated sub-byte data type vector load.
/// Specifically, it uses `emulatedElemType` to load a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
/// load size is given by `numEmulatedElementsToLoad`.
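/// For example (illustrative), loading 2 emulated `i8` elements that hold
/// `i4` data emits roughly:
///   %0 = vector.load %base[%index] : memref<?xi8>, vector<2xi8>
///   %1 = vector.bitcast %0 : vector<2xi8> to vector<4xi4>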
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                      Value base,
                                      OpFoldResult linearizedIndices,
                                      int64_t numEmulatedElementsToLoad,
                                      Type origElemType,
                                      Type emulatedElemType) {
  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
               origElemType.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numEmulatedElementsToLoad, emulatedElemType), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc, VectorType::get(numEmulatedElementsToLoad * scale, origElemType),
      newLoad);
}

/// Downcasts two values to `downcastType`, selects elements based on `mask`,
/// and casts the result to `upcastType`.
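/// For example (illustrative), with `downcastType = vector<4xi2>` and
/// `upcastType = vector<1xi8>`, both inputs are bitcast to `vector<4xi2>`
/// (if needed), combined with `arith.select %mask, %true, %false`, and the
/// result is bitcast back to `vector<1xi8>`.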
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
                                     Value trueValue, Value falseValue) {
  assert(
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
  }
  if (falseValue.getType() != downcastType) {
    falseValue =
        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
  }
  Value selectedValue =
      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
  // Upcast the selected value to the new type.
  return builder.create<vector::BitCastOp>(loc, upcastType, selectedValue);
}

/// Emits a `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
/// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
/// subbyte-sized elements totalling 8 bits, and the mask selects which
/// elements to store.
///
/// Inputs:
///   linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
///   storeIdx = 2
///   valueToStore = |3|3|3|3| : vector<4xi2>
///   mask = |0|0|1|1| : vector<4xi1>
///
/// Result:
///   linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
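///
/// The emitted IR is roughly (illustrative):
///   memref.generic_atomic_rmw %linearizedMemref[%storeIdx] : memref<1xi8> {
///   ^bb0(%current : i8):
///     // Wrap %current in a vector<1xi8>, select the masked lanes from
///     // %valueToStore via downcastSelectAndUpcast, and yield the new byte.
///     memref.atomic_yield %updated : i8
///   }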
static void atomicStore(OpBuilder &builder, Location loc,
                        MemRefValue linearizedMemref, Value storeIdx,
                        VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  // Create an atomic load-modify-write region using
  // `memref.generic_atomic_rmw`.
  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
      loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(atomicOp.getBody());

  // Load the original value from memory, and cast it to the original element
  // type.
  auto oneElemVecType = VectorType::get({1}, origValue.getType());
  Value origVecValue = builder.create<vector::FromElementsOp>(
      loc, oneElemVecType, ValueRange{origValue});

  // Construct the final masked value and yield it.
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}

/// Extracts `sliceNumElements` elements from source `vector` at
/// `extractOffset`, and inserts them into an empty vector at `insertOffset`.
/// Inputs:
///   vec_in  = |0|1|2|3| : vector<4xi2>
///   extractOffset = 1
///   sliceNumElements = 2
///   insertOffset = 2
/// Output:
///   vec_out = |0|0|1|2| : vector<4xi2>
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                  Location loc, VectorValue vector,
                                  int64_t extractOffset,
                                  int64_t sliceNumElements,
                                  int64_t insertOffset) {
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
  // TODO: update and use `alignedConversionPrecondition` in the place of
  // these asserts.
  assert(
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
      loc, VectorType::get({scale}, vectorElementType),
      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                              extractOffset, sliceNumElements);
  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
                                   insertOffset);
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertVectorStore
//===----------------------------------------------------------------------===//

struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto oldElementType = valueToStore.getType().getElementType();
    auto newElementType =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int numSrcElemsPerDest = dstBits / srcBits;

    // Adjust the number of elements to store when emulating narrow types.
    // Here only the 1-D vector store is considered, and the N-D memref types
    // should be linearized.
    // For example, to emulate i4 to i8, the following op:
    //
    // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
    //
    // can be replaced with
    //
    // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
    // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
    // vector<4xi8>

    auto origElements = valueToStore.getType().getNumElements();
    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedNumFrontPadElems =
        isAlignedEmulation
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");
    }

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    // Conditions when atomic RMWs are not needed:
    // 1. The source vector size (in bits) is a multiple of byte size.
    // 2. The address of the store is aligned to the emulated width boundary.
    //
    // For example, storing a vector<4xi2> into a memref<13xi2> at offset 4
    // does not need unaligned emulation because the store address is aligned
    // and the source is a whole byte.
    bool emulationRequiresPartialStores =
        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
    if (!emulationRequiresPartialStores) {
      // Basic case: storing full bytes.
      auto numElements = origElements / numSrcElemsPerDest;
      auto bitCast = rewriter.create<vector::BitCastOp>(
          loc, VectorType::get(numElements, newElementType),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
      return success();
    }

    // Next, handle the case when sub-byte read-modify-write
    // sequences are needed to emulate a vector store.
    // Here is an example:
    //
    // Vector to store: vector<7xi2>
    // Value to store: 11 11 11 11 11 11 11 (all ones)
    //
    // Destination: memref<12xi2>
    // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
    //
    // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
    //
    // Destination memref before:
    //
    //    Byte 0     Byte 1     Byte 2
    // +----------+----------+----------+
    // | 00000000 | 00000000 | 00000000 |
    // +----------+----------+----------+
    //
    // Destination memref after:
    //
    //    Byte 0     Byte 1     Byte 2
    // +----------+----------+----------+
    // | 00001111 | 11111111 | 11000000 |
    // +----------+----------+----------+
    //
    // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
    // need for atomicity). Stores to Bytes 0 and Byte 2 are "partial", hence
    // requiring RMW access (atomicity is required).

    // The index into the target memref we are storing to.
    Value currentDestIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    // The index into the source vector we are currently processing.
    auto currentSourceIndex = 0;

    // Build a mask used for rmw.
    auto subWidthStoreMaskType =
        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());

    // 1. Partial width store for the leading byte.
    // When the store address is not aligned to the emulated width boundary,
    // deal with the unaligned part so that the remaining elements are aligned
    // to the width boundary.
    auto frontSubWidthStoreElem =
        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                    origElements, true);
        frontSubWidthStoreElem = origElements;
      } else {
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    frontSubWidthStoreElem, true);
      }
      auto frontMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
      auto value =
          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);

      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
                  cast<VectorValue>(value), frontMask.getResult());
    }

    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);
      return success();
    }

    // Increment the destination index by 1 to align to the emulated width
    // boundary.
    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    currentDestIndex = rewriter.create<arith::AddIOp>(
        loc, rewriter.getIndexType(), currentDestIndex, constantOne);

    // 2. Full width store for the inner output bytes.
    // After the previous step, the store address is aligned to the emulated
    // width boundary.
    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / numSrcElemsPerDest;
    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
    if (fullWidthStoreSize > 0) {
      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                        fullWidthStorePart);
      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
                                       currentDestIndex);

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }

    // 3. Partial width store for the trailing output byte.
    // It is needed when the residual length is smaller than the emulated width,
    // which is not covered in step 2 above.
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
                               currentSourceIndex, remainingElements, 0);

      // Generate back mask.
      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
                  cast<VectorValue>(subWidthStorePart), backMask.getResult());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//

struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    int scale = dstBits / srcBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // Load the whole data and use arith.select to handle the corner cases.
    //
    // As an example, for this masked store of i4 values:
    //
    //   vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
    //
    // and given these input values:
    //
    //   %mask = [0, 1, 1, 1, 1, 0, 0, 0]                     (8 * i1)
    //   %0[%c0, %c0] =
    //      [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]          (8 * i4)
    //   %val_to_store =
    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]          (8 * i4)
    //
    // we'll have the following i4 output:
    //
    //    expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
    //
    // Emulating the above using i8 will give:
    //
    //    %compressed_mask = [1, 1, 1, 0]                     (4 * i1)
    //    %maskedload = [0x12, 0x34, 0x56, 0x00]              (4 * i8)
    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
    //    %select_using_shifted_mask =
    //      [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0]          (8 * i4)
    //    %packed_data = [0x1A, 0xBC, 0xD6, 0x00]             (4 * i8)
    //
    // Using the compressed mask to store %packed_data results in expected
    // output.
    //
    // FIXME: Make an example based on the comment above work (see #115460 for
    // reproducer).
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with vector.bitcast op.
    // Here only the 1-D vector load is considered, and the N-D memref types
    // should be linearized.
    // For example, to emulate i4 to i8, the following op:
    //
    // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
    //
    // can be replaced with
    //
    // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
    // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
    //
    // There are cases where the number of elements to load is not byte-aligned,
    // for example:
    //
    // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
    //
    // we will have to load extra bytes and extract the exact slice in between.
    //
    // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
    // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
    // %3 = vector.extract_strided_slice %2
    //        {offsets = [2], sizes = [3], strides = [1]}
    //        : vector<8xi2> to vector<3xi2>
    //
    // TODO: Currently the extract_strided_slice's attributes must be known at
    // compile time as they must be constants.

    auto origElements = op.getVectorType().getNumElements();
    bool isAlignedEmulation = origElements % scale == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isAlignedEmulation
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    // Always load enough elements to cover the original elements.
    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, oldElementType, newElementType);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isAlignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with vector.bitcast op.
    // For example, to emulate i4 to i8, the following op:
    //
    //   %mask = vector.constant_mask [3] : vector<6xi1>
    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
    //        memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
    //
    // can be replaced with
    //
    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
    //   %new_pass_thru = vector.bitcast %pass_thru :
    //        vector<6xi4> to vector<3xi8>
    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
    //        memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
    //   %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
    //
    // Since we are effectively loading 16 bits (2xi8) from the memref with the
    // new mask, while originally we only wanted to effectively load 12 bits
    // (3xi4) from the memref, we need to set the second half of the last i8
    // that was effectively loaded (i.e. the second i8) to %pass_thru.
    //
    //   %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
    //
    // Given these input values:
    //   %mask = [1, 1, 1, 0, 0, 0]
    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
    //   %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
    //
    // we'll have:
    //
    //   expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
    //
    //   %new_mask = [1, 1, 0]
    //   %new_pass_thru = [0x78, 0x9A, 0xBC]
    //   %1 = [0x12, 0x34, 0xBC]
    //   %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
    //   %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
    //
    // TODO: Currently, only loading an even number of elements is supported.
    // To deal with an odd number of elements, one has to extract the
    // subvector at the proper offset after bit-casting.
    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    bool isAlignedEmulation = origElements % scale == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isAlignedEmulation
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    auto loadType = VectorType::get(numElements, newElementType);
    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(
          rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
          origElements);
    } else if (!isAlignedEmulation) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generating the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // Setting the part that originally was not effectively loaded from memory
    // to pass through.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType =
        VectorType::get(numElements * scale, rewriter.getI1Type());
    // TODO: try to fold if op's mask is constant
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                        linearizedInfo.intraDataOffset,
                                        origElements);
    } else if (!isAlignedEmulation) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, op.getPassThru(),
          linearizedInfo.intraDataOffset, origElements);
    } else if (!isAlignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//

struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isAlignedEmulation = origElements % scale == 0;

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isAlignedEmulation
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements * scale, oldElementType), newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isAlignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// RewriteBitCastOfTruncI
//===----------------------------------------------------------------------===//

namespace {

/// Helper struct to keep track of the provenance of a contiguous set of bits
/// in a source vector.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits to *this.
  int64_t sourceElementIdx;
  /// The range of bits in the source vector element that contribute to *this.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in the SourceElementRangeList,
  /// compute the number of bits that need to be shifted to the left to get
  /// the bits in their final location. This shift amount is simply the sum of
  /// the bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are
  /// always the LSBs, the bits of `shuffleIdx = 1` come next, etc.).
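  /// For example (illustrative), for the row `{ 0: b@[0..5) }{ 1: b@[0..3) }`
  /// in the motivating example further below, `computeLeftShiftAmount(1)`
  /// returns 5.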
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};

/// Helper struct to enumerate the source elements and bit ranges that are
/// involved in a bitcast operation.
/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
/// any 1-D vector shape and any source/target bitwidths.
/// This creates and holds a mapping of the form:
/// [dstVectorElementJ] ==
///    [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
///   [0] = {0, [0-8)}
///   [1] = {0, [8-16)}
///   [2] = {0, [16-24)}
/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
///   [0] = {0, [0-10)}
///   [1] = {0, [10-15)}, {1, [0-5)}
///   [2] = {1, [5-15)}
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};

/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
/// BitCastBitsEnumerator encodes for each element of the target vector the
/// provenance of the bits in the source vector. We can "transpose" this
/// information to build a sequence of shuffles and bitwise ops that will
/// produce the desired result.
///
/// Consider the following motivating example:
/// ```
///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
/// ```
///
/// BitCastBitsEnumerator contains the following information:
/// ```
///   { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
///   { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
///   { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
///   { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
///   { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
///   { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
///   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
///   {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
///   {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
///   {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
///   {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
///   {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
///   {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
///   {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
///   {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
///   {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
///   {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
///   {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
///   {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
///   {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
/// ```
///
/// In the above, each row represents one target vector element and each
/// column represents one bit contribution from a source vector element.
/// The algorithm creates vector.shuffle operations (in this case 3 shuffles,
/// i.e. the max number of columns in BitCastBitsEnumerator). The algorithm
/// populates the bits as follows:
/// ```
///     src bits 0 ...
/// 1st shuffle |xxxxx   |xx      |...
/// 2nd shuffle |     xxx|  xxxxx |...
/// 3rd shuffle |        |       x|...
/// ```
///
/// The algorithm proceeds as follows:
///   1. for each vector.shuffle, collect the source vectors that participate in
///     this shuffle. One source vector per target element of the resulting
///     vector.shuffle. If there is no source element contributing bits for the
///     current vector.shuffle, take 0 (i.e. row 0 in the above example has only
///     2 columns).
///   2. represent the bitrange in the source vector as a mask. If there is no
///     source element contributing bits for the current vector.shuffle, take 0.
///   3. shift right by the proper amount to align the source bitrange at
///     position 0. This is exactly the low end of the bitrange. For instance,
///     the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
///     shift right by 3 to get the bits contributed by the source element #1
///     into position 0.
///   4. shift left by the proper amount to align to the desired position in
///     the result element vector. For instance, the contribution of the second
1203 ///     source element for the first row needs to be shifted by `5` to form the
1204 ///     first i8 result element.
1205 ///
/// Eventually, we end up building the sequence
1207 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1208 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1209 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
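///
/// For illustration, a single step of that sequence for the example above
/// might look roughly as follows (a hedged sketch only: the SSA names, the
/// shuffle indices and the constant splats depend on the precomputed metadata,
/// and the shuffled operand's element type is assumed to be i8):
/// ```
///   %s = vector.shuffle %src, %src [0, 1, 3, 4, 6, ...]
///          : vector<32xi8>, vector<32xi8>
///   %m = arith.andi %s, %mask : vector<20xi8>
///   %r = arith.shrui %m, %shr_amounts : vector<20xi8>
///   %l = arith.shli %r, %shl_amounts : vector<20xi8>
///   %acc_next = arith.ori %acc, %l : vector<20xi8>
/// ```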
1210 struct BitCastRewriter {
1211   /// Helper metadata struct to hold the static quantities for the rewrite.
1212   struct Metadata {
1213     SmallVector<int64_t> shuffles;
1214     SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1215   };
1216 
1217   BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1218 
1219   /// Verify that general preconditions for the rewrite are met.
1220   LogicalResult commonPrecondition(PatternRewriter &rewriter,
1221                                    VectorType preconditionType, Operation *op);
1222 
1223   /// Precompute the metadata for the rewrite.
1224   SmallVector<BitCastRewriter::Metadata>
1225   precomputeMetadata(IntegerType shuffledElementType);
1226 
1227   /// Rewrite one step of the sequence:
1228   ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1229   Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1230                            Value initialValue, Value runningResult,
1231                            const BitCastRewriter::Metadata &metadata);
1232 
1233 private:
  /// Underlying enumerator that encodes the provenance of the bits in each
1235   /// element of the result vector.
1236   BitCastBitsEnumerator enumerator;
1237 };
1238 
1239 } // namespace
1240 
1241 [[maybe_unused]] static raw_ostream &
1242 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1243   for (const auto &l : vec) {
1244     for (auto it : llvm::enumerate(l)) {
1245       os << "{ " << it.value().sourceElementIdx << ": b@["
1246          << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1247          << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1248     }
1249     os << "\n";
1250   }
1251   return os;
1252 }
1253 
1254 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1255                                              VectorType targetVectorType)
1256     : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1257 
1258   assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
1260   assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
1262   int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1263   int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1264   LDBG("sourceVectorType: " << sourceVectorType);
1265 
1266   int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1267   int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1268   LDBG("targetVectorType: " << targetVectorType);
1269 
1270   int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1271   (void)mostMinorSourceDim;
1272   assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1273          "source and target bitwidths must match");
1274 
  // Prepopulate one source element range list per target element.
1276   sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
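  // Walk the flat bit index of the concatenated bits and, for every overlap of
  // a source element with a target element, record the contributing sub-range
  // of source bits. E.g. for `2xi15 -> 3xi10`, the first two iterations record
  // {0, [0-10)} for target [0] and {0, [10-15)} for target [1].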
1277   for (int64_t resultBit = 0; resultBit < bitwidth;) {
1278     int64_t resultElement = resultBit / targetBitWidth;
1279     int64_t resultBitInElement = resultBit % targetBitWidth;
1280     int64_t sourceElementIdx = resultBit / sourceBitWidth;
1281     int64_t sourceBitInElement = resultBit % sourceBitWidth;
1282     int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1283                             targetBitWidth - resultBitInElement);
1284     sourceElementRanges[resultElement].push_back(
1285         {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1286     resultBit += step;
1287   }
1288 }
1289 
1290 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1291                                  VectorType targetVectorType)
1292     : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1293   LDBG("\n" << enumerator.sourceElementRanges);
1294 }
1295 
1296 /// Verify that the precondition type meets the common preconditions for any
1297 /// conversion.
1298 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1299                                                   VectorType preconditionType,
1300                                                   Operation *op) {
1301   if (!preconditionType || preconditionType.isScalable())
1302     return rewriter.notifyMatchFailure(op, "scalable vector");
1303 
1304   // TODO: consider relaxing this restriction in the future if we find ways
1305   // to really work with subbyte elements across the MLIR/LLVM boundary.
1306   unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1307   if (bitwidth % 8 != 0)
1308     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1309 
1310   return success();
1311 }
1312 
1313 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1314                                                   VectorType preconditionType,
1315                                                   Operation *op) {
1316   if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1317     return rewriter.notifyMatchFailure(op, "types are not vector");
1318 
1319   if (!preconditionType || preconditionType.getRank() != 1)
1320     return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1321 
1322   return commonConversionPrecondition(rewriter, preconditionType, op);
1323 }
1324 
1325 /// Verify that `subByteVecType` and `dstType` are aligned. Alignment
1326 /// means that:
1327 ///   1. The `dstType` element type is a multiple of the
///   `subByteVecType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
///   is not supported). Let this multiple be `N`.
///   2. The number of the (trailing) elements in `subByteVecType` is a
///   multiple of `N` from point 1 (e.g., when targeting i8, 2xi4 is OK, but
///   3xi4 is not supported).
1333 ///
1334 /// NOTE: This method assumes that common conversion preconditions are met. In
1335 /// particular, the element type of `dstType` is assumed to be a multi-byte
1336 /// type (e.g. i8, i16, i32).
1337 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1338                                                    VectorType subByteVecType,
1339                                                    VectorType dstType,
1340                                                    Operation *op) {
1341   if (!subByteVecType || !dstType)
1342     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1343   unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
1344   unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
1345 
1346   if (dstElemBitwidth < 8)
1347     return rewriter.notifyMatchFailure(
1348         op, "the bitwidth of dstType must be greater than or equal to 8");
1349   if (dstElemBitwidth % srcElemBitwidth != 0)
1350     return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1351   if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1352     return rewriter.notifyMatchFailure(
1353         op, "only src bitwidth of 2 or 4 is supported at this moment");
1354 
1355   const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1356   if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
1357     return rewriter.notifyMatchFailure(
1358         op, "the trailing dimension of the input vector of sub-bytes must be a "
1359             "multiple of 8 / <sub-byte-width>");
1360 
1361   return success();
1362 }
1363 
1364 SmallVector<BitCastRewriter::Metadata>
1365 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1366   SmallVector<BitCastRewriter::Metadata> result;
1367   for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1368        shuffleIdx < e; ++shuffleIdx) {
1369     SmallVector<int64_t> shuffles;
1370     SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1371 
1372     // Create the attribute quantities for the shuffle / mask / shift ops.
1373     for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1374       int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1375                                   ? srcEltRangeList[shuffleIdx].sourceElementIdx
1376                                   : 0;
1377       shuffles.push_back(sourceElement);
1378 
1379       int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1380                           ? srcEltRangeList[shuffleIdx].sourceBitBegin
1381                           : 0;
1382       int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1383                           ? srcEltRangeList[shuffleIdx].sourceBitEnd
1384                           : 0;
1385       IntegerAttr mask = IntegerAttr::get(
1386           shuffledElementType,
1387           llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1388                                   bitLo, bitHi));
1389       masks.push_back(mask);
1390 
1391       int64_t shiftRight = bitLo;
1392       shiftRightAmounts.push_back(
1393           IntegerAttr::get(shuffledElementType, shiftRight));
1394 
1395       int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1396       shiftLeftAmounts.push_back(
1397           IntegerAttr::get(shuffledElementType, shiftLeft));
1398     }
1399 
1400     result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1401   }
1402   return result;
1403 }
1404 
1405 Value BitCastRewriter::genericRewriteStep(
1406     PatternRewriter &rewriter, Location loc, Value initialValue,
1407     Value runningResult, const BitCastRewriter::Metadata &metadata) {
1408   // Create vector.shuffle from the metadata.
1409   auto shuffleOp = rewriter.create<vector::ShuffleOp>(
1410       loc, initialValue, initialValue, metadata.shuffles);
1411 
1412   // Intersect with the mask.
1413   VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1414   auto constOp = rewriter.create<arith::ConstantOp>(
1415       loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1416   Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
1417 
1418   // Align right on 0.
1419   auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
1420       loc,
1421       DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1422   Value shiftedRight =
1423       rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
1424 
1425   // Shift bits left into their final position.
1426   auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
1427       loc,
1428       DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1429   Value shiftedLeft =
1430       rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
1431 
1432   runningResult =
1433       runningResult
1434           ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
1435           : shiftedLeft;
1436 
1437   return runningResult;
1438 }
1439 
1440 /// Bitcasts the aligned `subByteVec` vector to a vector of i8.
/// Aligned means that it satisfies alignedConversionPrecondition.
1442 ///
1443 /// Example:
1444 /// vector<16x16xi2> -> vector<16x4xi8>
1445 /// vector<16x16xi4> -> vector<16x8xi8>
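///
/// For instance, the first case above produces a single op of the form
/// (illustrative SSA name):
///   %0 = vector.bitcast %src : vector<16x16xi2> to vector<16x4xi8>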
1446 static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1447                                       Value subByteVec) {
1448   auto srcVecType = cast<VectorType>(subByteVec.getType());
1449   int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1450   assert(8 % srcBitwidth == 0 &&
1451          "Unsupported sub-byte type (not a divisor of i8)");
1452   int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1453   SmallVector<int64_t> vecShape(srcVecType.getShape());
1454   // Adjust last dimension of the vector, so the total size remains the same.
1455   vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1456   auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1457   return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
1458 }
1459 
1460 /// Extracts a signed N-bit sequence from each element of a vector of bytes,
1461 /// starting at the specified bit index.
1462 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1463 ///
1464 /// Example for a single element:
1465 /// Extract numBits=2 starting at bitIdx=2
1466 /// src     = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1467 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1468 /// target  = [.   .   .   .   ^   ^   .   .]
1469 ///
1470 /// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1471 /// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1472 ///
1473 /// src     =                         [01 01 11 10]
1474 /// shl     = arith.shl(src, 4)    -> [11 10 00 00]
1475 /// result  = arith.shrsi(shl, 6)  -> [11 11 11 11]
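///
/// A sketch of the ops emitted for a hypothetical vector<4xi8> input (the
/// splat constants 4 and 6 follow from bitIdx=2 and numBits=2):
///   %c4  = arith.constant dense<4> : vector<4xi8>
///   %shl = arith.shli %src, %c4 : vector<4xi8>
///   %c6  = arith.constant dense<6> : vector<4xi8>
///   %res = arith.shrsi %shl, %c6 : vector<4xi8>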
1476 static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1477                                                   Location loc, Value src,
1478                                                   int bitIdx, int numBits) {
1479   auto srcType = cast<VectorType>(src.getType());
1480   Value shl = src;
1481   int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1482   assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1483          "Invalid bitIdx range");
1484   if (bitsToShiftLeft != 0) {
1485     Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
1486         loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1487     shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
1488   }
1489 
1490   int8_t bitsToShiftRight = 8 - numBits;
1491   Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1492       loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1493   Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
1494   return shr;
1495 }
1496 
1497 /// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1498 /// starting at the specified bit index.
1499 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1500 ///
1501 /// Example for a single element:
1502 /// Extract numBits=2 starting at bitIdx=2
1503 /// src     = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1504 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1505 /// target  = [.   .   .   .   ^   ^   .   .]
1506 ///
1507 /// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1508 /// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1509 ///
1510 /// src                            = [01 01 10 10]
1511 /// mask                           = [00 00 00 11]
1512 /// shr    = arith.shrui(src, 2)   = [00 01 01 10]
1513 /// result = arith.andi(shr, mask) = [00 00 00 10]
1514 /// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1515 /// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1516 /// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1517 /// left when the index is 0.
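///
/// A sketch of the ops emitted for a hypothetical vector<4xi8> input (the
/// splat constants 2 and 3 follow from bitIdx=2 and numBits=2):
///   %c2   = arith.constant dense<2> : vector<4xi8>
///   %shr  = arith.shrui %src, %c2 : vector<4xi8>
///   %mask = arith.constant dense<3> : vector<4xi8>
///   %res  = arith.andi %shr, %mask : vector<4xi8>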
1518 static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1519                                               Location loc, Value src,
1520                                               int bitIdx, int numBits) {
1521   assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1522          "Invalid bitIdx range");
1523   auto srcType = cast<VectorType>(src.getType());
1524   int8_t bitsToShiftRight = bitIdx;
1525   Value shr = src;
1526   if (bitsToShiftRight != 0) {
1527     Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1528         loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1529     shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
1530   }
1531   if (bitIdx + numBits == 8) {
1532     return shr;
1533   }
1534   uint8_t lowBitsMask = (1 << numBits) - 1;
1535   Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1536       loc, DenseElementsAttr::get(srcType, lowBitsMask));
1537   return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
1538 }
1539 
1540 using ExtractNBitsFn =
1541     std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1542 
/// Rewrite the i4 -> i8 extension into a bitcast, bitwise ops and an
/// interleave to avoid leaving LLVM to scramble with peephole optimizations.
1545 static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1546                               Value srcValue, const ExtractNBitsFn &extFn) {
1547   [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1548   assert(srcVecType.getElementType().isSignlessInteger(4) &&
1549          "Expected i4 type");
1550 
1551   // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1552   Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1553 
  // 2. Extend i4 elements to i8 elements. The low i4 elements of each
  // byte are placed in one vector and the high i4 elements in another vector.
1556   Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1557   Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1558 
1559   // 3. Interleave low and high i8 elements.
1560   return rewriter.create<vector::InterleaveOp>(loc, low, high);
1561 }
1562 
/// Rewrite the i2 -> i8 extension into a bitcast, bitwise ops and interleaves
/// to avoid leaving LLVM to scramble with peephole optimizations.
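///
/// A sketch of the structure produced for a hypothetical vector<16xi2> input
/// (the four 2-bit extracts are emitted by `extFn` and only summarized here):
///   %0 = vector.bitcast %in : vector<16xi2> to vector<4xi8>
///   %b0, %b1, %b2, %b3 = 2-bit extracts of %0 at bit positions 0, 2, 4, 6
///   %even = vector.interleave %b0, %b2 : vector<4xi8> -> vector<8xi8>
///   %odd  = vector.interleave %b1, %b3 : vector<4xi8> -> vector<8xi8>
///   %res  = vector.interleave %even, %odd : vector<8xi8> -> vector<16xi8>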
1565 static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1566                               Value srcValue, const ExtractNBitsFn &extFn) {
1567   [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1568   assert(srcVecType.getElementType().isSignlessInteger(2) &&
1569          "Expected i2 type");
1570 
  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
1572   Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1573 
  // 2. Extract each i2 element.
  // Position 0 (bits 0-1)
1576   Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1577   // Position 1 (bits 2-3)
1578   Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1579   // Position 2 (bits 4-5)
1580   Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1581   // Position 3 (bits 6-7)
1582   Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
  // 3. Interleave all 4 elements by first interleaving the even positions
  // and then the odd positions:
1585   // even elements and then odd
1586   // vec0  = [0,0,0,0],...
1587   // vec1  = [1,1,1,1],...
1588   // vec2  = [2,2,2,2],...
1589   // vec3  = [3,3,3,3],...
1590   // 02    = [0,2,0,2,0,2,0,2],...
1591   // 13    = [1,3,1,3,1,3,1,3],...
1592   // 0213  = [0,1,2,3,...],...
1593   Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
1594   Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
1595   return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
1596 }
1597 
1598 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1599 /// ops to avoid leaving LLVM to scramble with peephole optimizations.
1600 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1601                                 Value srcValue) {
1602   VectorType srcVecType = cast<VectorType>(srcValue.getType());
1603   assert(srcVecType.getElementType().isSignlessInteger(8) &&
1604          "Expected i8 type");
1605 
1606   // 1. De-interleave low and high i8 elements.
1607   auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
1608 
1609   // 2. Zero out the upper side of each low i8 element.
1610   constexpr int8_t i8LowBitMask = 0x0F;
1611   VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1612   Value zeroOutMask = rewriter.create<arith::ConstantOp>(
1613       loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1614   Value zeroOutLow = rewriter.create<arith::AndIOp>(
1615       loc, deinterleaveOp.getRes1(), zeroOutMask);
1616 
1617   // 3. Move high i4 values to upper side of the byte.
1618   constexpr int8_t bitsToShift = 4;
1619   auto shiftValues = rewriter.create<arith::ConstantOp>(
1620       loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1621   Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
1622                                                  shiftValues);
1623 
1624   // 4. Merge high and low i4 values.
1625   auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
1626 
1627   // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1628   auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1629   return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
1630 }
1631 
1632 namespace {
1633 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1634 /// advantage of high-level information to avoid leaving LLVM to scramble with
1635 /// peephole optimizations.
1636 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1637   using OpRewritePattern::OpRewritePattern;
1638 
1639   LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1640                                 PatternRewriter &rewriter) const override {
1641     // The source must be a trunc op.
1642     auto truncOp =
1643         bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1644     if (!truncOp)
1645       return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1646 
1647     // Set up the BitCastRewriter and verify the precondition.
1648     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1649     VectorType targetVectorType = bitCastOp.getResultVectorType();
1650     BitCastRewriter bcr(sourceVectorType, targetVectorType);
1651     if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1652       return failure();
1653 
1654     // Perform the rewrite.
1655     Value truncValue = truncOp.getIn();
1656     auto shuffledElementType =
1657         cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1658     Value runningResult;
1659     for (const BitCastRewriter ::Metadata &metadata :
1660          bcr.precomputeMetadata(shuffledElementType)) {
1661       runningResult = bcr.genericRewriteStep(
1662           rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1663     }
1664 
1665     // Finalize the rewrite.
1666     bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1667                      shuffledElementType.getIntOrFloatBitWidth();
1668     if (narrowing) {
1669       if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1670         rewriter.replaceOp(bitCastOp, runningResult);
1671       } else {
1672         rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1673             bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1674       }
1675     } else {
1676       if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1677         rewriter.replaceOp(bitCastOp, runningResult);
1678       } else {
1679         rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1680             bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1681       }
1682     }
1683 
1684     return success();
1685   }
1686 };
1687 } // namespace
1688 
1689 //===----------------------------------------------------------------------===//
1690 // RewriteExtOfBitCast
1691 //===----------------------------------------------------------------------===//
1692 
1693 namespace {
1694 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
1695 /// take advantage of high-level information to avoid leaving LLVM to scramble
1696 /// with peephole optimizations.
1697 template <typename ExtOpType>
1698 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1699   using OpRewritePattern<ExtOpType>::OpRewritePattern;
1700 
1701   RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1702       : OpRewritePattern<ExtOpType>(context, benefit) {}
1703 
1704   LogicalResult matchAndRewrite(ExtOpType extOp,
1705                                 PatternRewriter &rewriter) const override {
1706     // The source must be a bitcast op.
1707     auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1708     if (!bitCastOp)
1709       return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1710 
1711     // Set up the BitCastRewriter and verify the precondition.
1712     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1713     VectorType targetVectorType = bitCastOp.getResultVectorType();
1714     BitCastRewriter bcr(sourceVectorType, targetVectorType);
1715     if (failed(bcr.commonPrecondition(
1716             rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1717       return failure();
1718 
1719     // Perform the rewrite.
1720     Value runningResult;
1721     Value sourceValue = bitCastOp.getSource();
1722     auto shuffledElementType =
1723         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1724     for (const BitCastRewriter::Metadata &metadata :
1725          bcr.precomputeMetadata(shuffledElementType)) {
1726       runningResult = bcr.genericRewriteStep(
1727           rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1728     }
1729 
1730     // Finalize the rewrite.
1731     bool narrowing =
1732         cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1733         shuffledElementType.getIntOrFloatBitWidth();
1734     if (narrowing) {
1735       rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1736           extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1737     } else {
1738       rewriter.replaceOpWithNewOp<ExtOpType>(
1739           extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1740     }
1741 
1742     return success();
1743   }
1744 };
1745 
1746 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
1747 /// bitwise ops that take advantage of high-level information to avoid leaving
1748 /// LLVM to scramble with peephole optimizations. Templated to choose between
1749 /// signed and unsigned conversions.
1750 ///
1751 /// For example (signed):
1752 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
///      is rewritten as
1754 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1755 ///        %1 = arith.shli %0, 4 : vector<4xi8>
1756 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
1757 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
1758 ///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1759 ///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
1760 ///
1761 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
///      is rewritten as
1763 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1764 ///        %1 = arith.shli %0, 4 : vector<4xi8>
1765 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
1766 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
1767 ///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1768 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
1769 ///
1770 /// Example (unsigned):
1771 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
1772 ///      is rewritten as
1773 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1774 ///        %1 = arith.andi %0, 15 : vector<4xi8>
1775 ///        %2 = arith.shrui %0, 4 : vector<4xi8>
1776 ///        %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
1777 ///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
1778 ///
1779 template <typename ConversionOpType, bool isSigned>
1780 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1781   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1782 
1783   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1784                                 PatternRewriter &rewriter) const override {
1785     // Verify the preconditions.
1786     Value srcValue = conversionOp.getIn();
1787     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1788     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1789 
1790     if (failed(
1791             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1792       return failure();
1793 
1794     // Check general alignment preconditions.
1795     if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1796                                              conversionOp)))
1797       return failure();
1798 
1799     // Perform the rewrite.
1800     Location loc = conversionOp.getLoc();
1801     const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
1802                                  : extractNBitsPerByteAndExtendToI8;
1803     Value subByteExt;
1804     switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
1805     case 2:
1806       subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
1807       break;
1808     case 4:
1809       subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
1810       break;
1811     default:
1812       return failure();
1813     }
1814 
1815     // Finalize the rewrite.
1816     rewriter.replaceOpWithNewOp<ConversionOpType>(
1817         conversionOp, conversionOp.getType(), subByteExt);
1818     return success();
1819   }
1820 };
1821 
1822 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
1823 /// bitwise ops that take advantage of high-level information to avoid leaving
1824 /// LLVM to scramble with peephole optimizations.
1825 ///
1826 /// For example:
1827 ///    arith.trunci %in : vector<8xi32> to vector<8xi4>
///      is rewritten as
1829 ///
1830 ///        %cst = arith.constant dense<15> : vector<4xi8>
1831 ///        %cst_0 = arith.constant dense<4> : vector<4xi8>
///        %in_i8 = arith.trunci %in : vector<8xi32> to vector<8xi8>
///        %0, %1 = vector.deinterleave %in_i8 : vector<8xi8> -> vector<4xi8>
1833 ///        %2 = arith.andi %0, %cst : vector<4xi8>
1834 ///        %3 = arith.shli %1, %cst_0 : vector<4xi8>
1835 ///        %4 = arith.ori %2, %3 : vector<4xi8>
1836 ///        %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
1837 ///
1838 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1839   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
1840 
1841   LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
1842                                 PatternRewriter &rewriter) const override {
1843     // Verify the preconditions.
1844     Value srcValue = truncOp.getIn();
1845     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1846     auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1847     if (!srcVecType || !dstVecType)
1848       return failure();
1849 
1850     if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
1851       return failure();
1852 
1853     // TODO: Add support for truncating to i2.
1854     if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
1855       return failure();
1856 
1857     // Check general alignment preconditions. We invert the src/dst type order
1858     // to reuse the existing precondition logic.
1859     if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1860                                              truncOp)))
1861       return failure();
1862 
1863     // Create a new iX -> i8 truncation op.
1864     Location loc = truncOp.getLoc();
1865     auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
1866     Value i8TruncVal =
1867         rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
1868 
1869     // Rewrite the i8 -> i4 truncation part.
1870     Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
1871 
1872     // Finalize the rewrite.
1873     rewriter.replaceOp(truncOp, subByteTrunc);
1874     return success();
1875   }
1876 };
1877 
1878 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
1879 /// perform the transpose on wider (byte) element types.
1880 /// For example:
1881 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
1882 ///
1883 ///   is rewritten as:
1884 ///
///   %0 = arith.extsi %a : vector<8x16xi4> to vector<8x16xi8>
1886 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
1887 ///   %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
1888 ///
1889 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1890   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
1891 
1892   RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
1893       : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
1894 
1895   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
1896                                 PatternRewriter &rewriter) const override {
1897     // Precondition: sub-byte integer transpose.
1898     constexpr unsigned minNativeBitwidth = 8;
1899     VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1900     if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1901         srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1902       return rewriter.notifyMatchFailure(transposeOp,
1903                                          "not a sub-byte transpose");
1904     }
1905 
1906     // Perform the rewrite.
1907     Location loc = transposeOp.getLoc();
1908     // Signed/unsigned interpretation shouldn't matter here as we are just
1909     // transposing the elements and truncating them back to the original size.
1910     // TODO: Use unsigned extension (more efficient) when emulation or backend
1911     // support is available.
1912     auto srcNativeVecType = srcSubByteVecType.cloneWith(
1913         std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
1914     Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
1915                                                   transposeOp.getVector());
1916     Value newTranspose = rewriter.create<vector::TransposeOp>(
1917         loc, extOp, transposeOp.getPermutation());
1918     VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1919     rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
1920                                                  newTranspose);
1921     return success();
1922   }
1923 };
1924 
1925 } // namespace
1926 
1927 //===----------------------------------------------------------------------===//
1928 // Public Interface Definition
1929 //===----------------------------------------------------------------------===//
1930 
1931 void vector::populateVectorNarrowTypeEmulationPatterns(
1932     const arith::NarrowTypeEmulationConverter &typeConverter,
1933     RewritePatternSet &patterns) {
1934 
1935   // Populate `vector.*` conversion patterns.
1936   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1937                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1938       typeConverter, patterns.getContext());
1939 }
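
// For illustration, these conversion patterns are typically driven through a
// dialect conversion with an arith::NarrowTypeEmulationConverter, roughly as
// follows (a hedged sketch: the chosen bitwidth, the legality setup of the
// ConversionTarget and the surrounding pass boilerplate are illustrative and
// largely elided):
//
//   // Emulate sub-byte loads/stores with 8-bit accesses.
//   arith::NarrowTypeEmulationConverter typeConverter(8);
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();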
1940 
1941 void vector::populateVectorNarrowTypeRewritePatterns(
1942     RewritePatternSet &patterns, PatternBenefit benefit) {
1943   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
1944                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
1945                                                     benefit);
1946 
  // Patterns for aligned cases. We set a higher benefit so that they take
  // precedence over the generic patterns above, as they are expected to
  // generate better code.
1949   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
1950                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
1951                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
1952                                               benefit.getBenefit() + 1);
1953   patterns
1954       .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
1955            RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
1956           patterns.getContext(), benefit.getBenefit() + 1);
1957 }
1958 
1959 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
1960     RewritePatternSet &patterns, PatternBenefit benefit) {
1961   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
1962 }
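
// For illustration, the greedy-rewrite-based patterns above might be wired up
// in a pass roughly as follows (a minimal sketch: the pass boilerplate is
// hypothetical, and applyPatternsAndFoldGreedily, or its current equivalent,
// comes from mlir/Transforms/GreedyPatternRewriteDriver.h, which this file
// does not include):
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorNarrowTypeRewritePatterns(patterns, /*benefit=*/1);
//     vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns,
//                                                              /*benefit=*/1);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }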
1963