xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (revision bd5d361c059814435bab24189e79e01d94c7039d)
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include <cstdint>
24 #include <numeric>
25 
26 using namespace mlir;
27 
28 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
29   auto resultTypes = op->getResultTypes();
30   for (auto resType : resultTypes) {
31     VectorType vecType = dyn_cast<VectorType>(resType);
32     // Reject index since getElementTypeBitWidth will abort for Index types.
33     if (!vecType || vecType.getElementType().isIndex())
34       return false;
35     // There are no dimension to fold if it is a 0-D vector.
36     if (vecType.getRank() == 0)
37       return false;
38     unsigned trailingVecDimBitWidth =
39         vecType.getShape().back() * vecType.getElementTypeBitWidth();
40     if (trailingVecDimBitWidth >= targetBitWidth)
41       return false;
42   }
43   return true;
44 }
45 
46 static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
47   VectorType vecType = dyn_cast<VectorType>(t);
48   // Reject index since getElementTypeBitWidth will abort for Index types.
49   if (!vecType || vecType.getElementType().isIndex())
50     return false;
51   // There are no dimension to fold if it is a 0-D vector.
52   if (vecType.getRank() == 0)
53     return false;
54   unsigned trailingVecDimBitWidth =
55       vecType.getShape().back() * vecType.getElementTypeBitWidth();
56   return trailingVecDimBitWidth <= targetBitWidth;
57 }
58 
59 namespace {
60 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
61   using OpConversionPattern::OpConversionPattern;
62   LinearizeConstant(
63       const TypeConverter &typeConverter, MLIRContext *context,
64       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
65       PatternBenefit benefit = 1)
66       : OpConversionPattern(typeConverter, context, benefit),
67         targetVectorBitWidth(targetVectBitWidth) {}
68   LogicalResult
69   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
70                   ConversionPatternRewriter &rewriter) const override {
71     Location loc = constOp.getLoc();
72     auto resType =
73         getTypeConverter()->convertType<VectorType>(constOp.getType());
74 
75     if (!resType)
76       return rewriter.notifyMatchFailure(loc, "can't convert return type");
77 
78     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
79       return rewriter.notifyMatchFailure(
80           loc,
81           "Cannot linearize a constant scalable vector that's not a splat");
82 
83     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
84       return rewriter.notifyMatchFailure(
85           loc, "Can't flatten since targetBitWidth <= OpSize");
86     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
87     if (!dstElementsAttr)
88       return rewriter.notifyMatchFailure(loc, "unsupported attr type");
89 
90     dstElementsAttr = dstElementsAttr.reshape(resType);
91     rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
92                                                    dstElementsAttr);
93     return success();
94   }
95 
96 private:
97   unsigned targetVectorBitWidth;
98 };
99 
100 struct LinearizeVectorizable final
101     : OpTraitConversionPattern<OpTrait::Vectorizable> {
102   using OpTraitConversionPattern::OpTraitConversionPattern;
103 
104 public:
105   LinearizeVectorizable(
106       const TypeConverter &typeConverter, MLIRContext *context,
107       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
108       PatternBenefit benefit = 1)
109       : OpTraitConversionPattern(typeConverter, context, benefit),
110         targetVectorBitWidth(targetVectBitWidth) {}
111   LogicalResult
112   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
113                   ConversionPatternRewriter &rewriter) const override {
114     if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
115       return rewriter.notifyMatchFailure(
116           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
117     FailureOr<Operation *> newOp =
118         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
119     if (failed(newOp))
120       return failure();
121 
122     rewriter.replaceOp(op, (*newOp)->getResults());
123     return success();
124   }
125 
126 private:
127   unsigned targetVectorBitWidth;
128 };
129 
130 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
131 /// on a linearized vector.
132 /// Following,
133 ///   vector.extract_strided_slice %source
134 ///         { offsets = [..], strides = [..], sizes = [..] }
135 /// is converted to :
136 ///   %source_1d = vector.shape_cast %source
137 ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
138 ///   %out_nd = vector.shape_cast %out_1d
139 /// `shuffle_indices_1d` is computed using the offsets and sizes of the
140 /// extraction.
141 struct LinearizeVectorExtractStridedSlice final
142     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
143   using OpConversionPattern::OpConversionPattern;
144   LinearizeVectorExtractStridedSlice(
145       const TypeConverter &typeConverter, MLIRContext *context,
146       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
147       PatternBenefit benefit = 1)
148       : OpConversionPattern(typeConverter, context, benefit),
149         targetVectorBitWidth(targetVectBitWidth) {}
150 
151   LogicalResult
152   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
153                   ConversionPatternRewriter &rewriter) const override {
154     VectorType dstType =
155         getTypeConverter()->convertType<VectorType>(extractOp.getType());
156     assert(dstType && "vector type destination expected.");
157     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
158       return rewriter.notifyMatchFailure(extractOp,
159                                          "scalable vectors are not supported.");
160     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
161       return rewriter.notifyMatchFailure(
162           extractOp, "Can't flatten since targetBitWidth <= OpSize");
163 
164     ArrayAttr offsets = extractOp.getOffsets();
165     ArrayAttr sizes = extractOp.getSizes();
166     ArrayAttr strides = extractOp.getStrides();
167     if (!isConstantIntValue(strides[0], 1))
168       return rewriter.notifyMatchFailure(
169           extractOp, "Strided slice with stride != 1 is not supported.");
170     Value srcVector = adaptor.getVector();
171     // If kD offsets are specified for nD source vector (n > k), the granularity
172     // of the extraction is greater than 1. In this case last (n-k) dimensions
173     // form the extraction granularity.
174     // Example :
175     //  vector.extract_strided_slice %src {
176     //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
177     //      vector<4x8x8xf32> to vector<2x2x8xf32>
178     // Here, extraction granularity is 8.
179     int64_t extractGranularitySize = 1;
180     int64_t nD = extractOp.getSourceVectorType().getRank();
181     int64_t kD = (int64_t)offsets.size();
182     int64_t k = kD;
183     while (k < nD) {
184       extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
185       ++k;
186     }
187     // Get total number of extracted slices.
188     int64_t nExtractedSlices = 1;
189     for (Attribute size : sizes) {
190       nExtractedSlices *= cast<IntegerAttr>(size).getInt();
191     }
192     // Compute the strides of the source vector considering first k dimensions.
193     llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
194     for (int i = kD - 2; i >= 0; --i) {
195       sourceStrides[i] = sourceStrides[i + 1] *
196                          extractOp.getSourceVectorType().getShape()[i + 1];
197     }
198     // Final shuffle indices has nExtractedSlices * extractGranularitySize
199     // elements.
200     llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
201                                           extractGranularitySize);
202     // Compute the strides of the extracted kD vector.
203     llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
204     // Compute extractedStrides.
205     for (int i = kD - 2; i >= 0; --i) {
206       extractedStrides[i] =
207           extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
208     }
209     // Iterate over all extracted slices from 0 to nExtractedSlices - 1
210     // and compute the multi-dimensional index and the corresponding linearized
211     // index within the source vector.
212     for (int64_t i = 0; i < nExtractedSlices; ++i) {
213       int64_t index = i;
214       // Compute the corresponding multi-dimensional index.
215       llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
216       for (int64_t j = 0; j < kD; ++j) {
217         multiDimIndex[j] = (index / extractedStrides[j]);
218         index -= multiDimIndex[j] * extractedStrides[j];
219       }
220       // Compute the corresponding linearized index in the source vector
221       // i.e. shift the multiDimIndex by the offsets.
222       int64_t linearizedIndex = 0;
223       for (int64_t j = 0; j < kD; ++j) {
224         linearizedIndex +=
225             (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
226             sourceStrides[j];
227       }
228       // Fill the indices array form linearizedIndex to linearizedIndex +
229       // extractGranularitySize.
230       for (int64_t j = 0; j < extractGranularitySize; ++j) {
231         indices[i * extractGranularitySize + j] = linearizedIndex + j;
232       }
233     }
234     // Perform a shuffle to extract the kD vector.
235     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
236         extractOp, dstType, srcVector, srcVector, indices);
237     return success();
238   }
239 
240 private:
241   unsigned targetVectorBitWidth;
242 };
243 
244 /// This pattern converts the ShuffleOp that works on nD (n > 1)
245 /// vectors to a ShuffleOp that works on linearized vectors.
246 /// Following,
247 ///   vector.shuffle %v1, %v2 [ shuffle_indices ]
248 /// is converted to :
249 ///   %v1_1d = vector.shape_cast %v1
250 ///   %v2_1d = vector.shape_cast %v2
251 ///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
252 ///   %out_nd = vector.shape_cast %out_1d
253 // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
254 /// of the original shuffle operation.
255 struct LinearizeVectorShuffle final
256     : public OpConversionPattern<vector::ShuffleOp> {
257   using OpConversionPattern::OpConversionPattern;
258   LinearizeVectorShuffle(
259       const TypeConverter &typeConverter, MLIRContext *context,
260       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
261       PatternBenefit benefit = 1)
262       : OpConversionPattern(typeConverter, context, benefit),
263         targetVectorBitWidth(targetVectBitWidth) {}
264 
265   LogicalResult
266   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
267                   ConversionPatternRewriter &rewriter) const override {
268     VectorType dstType =
269         getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
270     assert(dstType && "vector type destination expected.");
271     // The assert is used because vector.shuffle does not support scalable
272     // vectors.
273     assert(!(shuffleOp.getV1VectorType().isScalable() ||
274              shuffleOp.getV2VectorType().isScalable() ||
275              dstType.isScalable()) &&
276            "scalable vectors are not supported.");
277     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
278       return rewriter.notifyMatchFailure(
279           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
280 
281     Value vec1 = adaptor.getV1();
282     Value vec2 = adaptor.getV2();
283     int shuffleSliceLen = 1;
284     int rank = shuffleOp.getV1().getType().getRank();
285 
286     // If rank > 1, we need to do the shuffle in the granularity of slices
287     // instead of scalars. Size of the slice is equal to the rank-1 innermost
288     // dims. Mask of the shuffle op specifies which slice to take from the
289     // outermost dim.
290     if (rank > 1) {
291       llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
292       for (unsigned i = 1; i < shape.size(); ++i) {
293         shuffleSliceLen *= shape[i];
294       }
295     }
296 
297     // For each value in the mask, we generate the indices of the source vectors
298     // that needs to be shuffled to the destination vector. If shuffleSliceLen >
299     // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
300     // elements) instead of scalars.
301     ArrayRef<int64_t> mask = shuffleOp.getMask();
302     int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
303     llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
304     for (auto [i, value] : llvm::enumerate(mask)) {
305       std::iota(indices.begin() + shuffleSliceLen * i,
306                 indices.begin() + shuffleSliceLen * (i + 1),
307                 shuffleSliceLen * value);
308     }
309 
310     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
311                                                    vec2, indices);
312     return success();
313   }
314 
315 private:
316   unsigned targetVectorBitWidth;
317 };
318 
319 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
320 /// linearized vector.
321 /// Following,
322 ///   vector.extract %source [ position ]
323 /// is converted to :
324 ///   %source_1d = vector.shape_cast %source
325 ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
326 ///   %out_nd = vector.shape_cast %out_1d
327 /// `shuffle_indices_1d` is computed using the position of the original extract.
328 struct LinearizeVectorExtract final
329     : public OpConversionPattern<vector::ExtractOp> {
330   using OpConversionPattern::OpConversionPattern;
331   LinearizeVectorExtract(
332       const TypeConverter &typeConverter, MLIRContext *context,
333       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
334       PatternBenefit benefit = 1)
335       : OpConversionPattern(typeConverter, context, benefit),
336         targetVectorBitWidth(targetVectBitWidth) {}
337   LogicalResult
338   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
339                   ConversionPatternRewriter &rewriter) const override {
340     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
341     if (!dstTy)
342       return rewriter.notifyMatchFailure(extractOp,
343                                          "expected n-D vector type.");
344 
345     if (extractOp.getVector().getType().isScalable() ||
346         cast<VectorType>(dstTy).isScalable())
347       return rewriter.notifyMatchFailure(extractOp,
348                                          "scalable vectors are not supported.");
349     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
350       return rewriter.notifyMatchFailure(
351           extractOp, "Can't flatten since targetBitWidth <= OpSize");
352 
353     // Dynamic position is not supported.
354     if (extractOp.hasDynamicPosition())
355       return rewriter.notifyMatchFailure(extractOp,
356                                          "dynamic position is not supported.");
357 
358     llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
359     int64_t size = extractOp.getVector().getType().getNumElements();
360 
361     // Compute linearized offset.
362     int64_t linearizedOffset = 0;
363     llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
364     for (auto [i, off] : llvm::enumerate(offsets)) {
365       size /= shape[i];
366       linearizedOffset += offsets[i] * size;
367     }
368 
369     llvm::SmallVector<int64_t, 2> indices(size);
370     std::iota(indices.begin(), indices.end(), linearizedOffset);
371     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
372         extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
373 
374     return success();
375   }
376 
377 private:
378   unsigned targetVectorBitWidth;
379 };
380 
381 /// This pattern converts the InsertOp to a ShuffleOp that works on a
382 /// linearized vector.
383 /// Following,
384 ///   vector.insert %source %destination [ position ]
385 /// is converted to :
386 ///   %source_1d = vector.shape_cast %source
387 ///   %destination_1d = vector.shape_cast %destination
388 ///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
389 ///   ] %out_nd = vector.shape_cast %out_1d
390 /// `shuffle_indices_1d` is computed using the position of the original insert.
391 struct LinearizeVectorInsert final
392     : public OpConversionPattern<vector::InsertOp> {
393   using OpConversionPattern::OpConversionPattern;
394   LinearizeVectorInsert(
395       const TypeConverter &typeConverter, MLIRContext *context,
396       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
397       PatternBenefit benefit = 1)
398       : OpConversionPattern(typeConverter, context, benefit),
399         targetVectorBitWidth(targetVectBitWidth) {}
400   LogicalResult
401   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
402                   ConversionPatternRewriter &rewriter) const override {
403     VectorType dstTy = getTypeConverter()->convertType<VectorType>(
404         insertOp.getDestVectorType());
405     assert(dstTy && "vector type destination expected.");
406     if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
407       return rewriter.notifyMatchFailure(insertOp,
408                                          "scalable vectors are not supported.");
409 
410     if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
411                                          targetVectorBitWidth))
412       return rewriter.notifyMatchFailure(
413           insertOp, "Can't flatten since targetBitWidth < OpSize");
414 
415     // dynamic position is not supported
416     if (insertOp.hasDynamicPosition())
417       return rewriter.notifyMatchFailure(insertOp,
418                                          "dynamic position is not supported.");
419     auto srcTy = insertOp.getSourceType();
420     auto srcAsVec = dyn_cast<VectorType>(srcTy);
421     uint64_t srcSize = 0;
422     if (srcAsVec) {
423       srcSize = srcAsVec.getNumElements();
424     } else {
425       return rewriter.notifyMatchFailure(insertOp,
426                                          "scalars are not supported.");
427     }
428 
429     auto dstShape = insertOp.getDestVectorType().getShape();
430     const auto dstSize = insertOp.getDestVectorType().getNumElements();
431     auto dstSizeForOffsets = dstSize;
432 
433     // compute linearized offset
434     int64_t linearizedOffset = 0;
435     auto offsetsNd = insertOp.getStaticPosition();
436     for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
437       dstSizeForOffsets /= dstShape[dim];
438       linearizedOffset += offset * dstSizeForOffsets;
439     }
440 
441     llvm::SmallVector<int64_t, 2> indices(dstSize);
442     auto origValsUntil = indices.begin();
443     std::advance(origValsUntil, linearizedOffset);
444     std::iota(indices.begin(), origValsUntil,
445               0); // original values that remain [0, offset)
446     auto newValsUntil = origValsUntil;
447     std::advance(newValsUntil, srcSize);
448     std::iota(origValsUntil, newValsUntil,
449               dstSize); // new values [offset, offset+srcNumElements)
450     std::iota(newValsUntil, indices.end(),
451               linearizedOffset + srcSize); // the rest of original values
452                                            // [offset+srcNumElements, end)
453 
454     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
455         insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
456 
457     return success();
458   }
459 
460 private:
461   unsigned targetVectorBitWidth;
462 };
463 
464 /// This pattern converts the BitCastOp that works on nD (n > 1)
465 /// vectors to a BitCastOp that works on linearized vectors.
466 /// Following,
467 ///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
468 /// is converted to :
469 ///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
470 ///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
471 ///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
472 struct LinearizeVectorBitCast final
473     : public OpConversionPattern<vector::BitCastOp> {
474   using OpConversionPattern::OpConversionPattern;
475   LinearizeVectorBitCast(
476       const TypeConverter &typeConverter, MLIRContext *context,
477       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
478       PatternBenefit benefit = 1)
479       : OpConversionPattern(typeConverter, context, benefit),
480         targetVectorBitWidth(targetVectBitWidth) {}
481   LogicalResult
482   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
483                   ConversionPatternRewriter &rewriter) const override {
484     Location loc = castOp.getLoc();
485     auto resType = getTypeConverter()->convertType(castOp.getType());
486     if (!resType)
487       return rewriter.notifyMatchFailure(loc, "can't convert return type.");
488 
489     if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
490       return rewriter.notifyMatchFailure(
491           loc, "Can't flatten since targetBitWidth <= OpSize");
492 
493     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
494                                                    adaptor.getSource());
495     return mlir::success();
496   }
497 
498 private:
499   unsigned targetVectorBitWidth;
500 };
501 
502 } // namespace
503 
504 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
505     TypeConverter &typeConverter, RewritePatternSet &patterns,
506     ConversionTarget &target, unsigned targetBitWidth) {
507 
508   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
509     if (!isLinearizableVector(type))
510       return type;
511 
512     return VectorType::get(type.getNumElements(), type.getElementType(),
513                            type.isScalable());
514   });
515 
516   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
517                             Location loc) -> Value {
518     if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
519         !isa<VectorType>(type))
520       return nullptr;
521 
522     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
523   };
524   typeConverter.addSourceMaterialization(materializeCast);
525   typeConverter.addTargetMaterialization(materializeCast);
526   target.markUnknownOpDynamicallyLegal(
527       [=](Operation *op) -> std::optional<bool> {
528         if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
529              op->hasTrait<OpTrait::Vectorizable>())) {
530           return (isLessThanTargetBitWidth(op, targetBitWidth)
531                       ? typeConverter.isLegal(op)
532                       : true);
533         }
534         return std::nullopt;
535       });
536 
537   patterns
538       .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
539           typeConverter, patterns.getContext(), targetBitWidth);
540 }
541 
542 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
543     const TypeConverter &typeConverter, RewritePatternSet &patterns,
544     ConversionTarget &target, unsigned int targetBitWidth) {
545   target.addDynamicallyLegalOp<vector::ShuffleOp>(
546       [=](vector::ShuffleOp shuffleOp) -> bool {
547         return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
548                    ? (typeConverter.isLegal(shuffleOp) &&
549                       cast<mlir::VectorType>(shuffleOp.getResult().getType())
550                               .getRank() == 1)
551                    : true;
552       });
553   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
554                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
555       typeConverter, patterns.getContext(), targetBitWidth);
556 }
557