xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (revision bd5d361c059814435bab24189e79e01d94c7039d)
135ef3994SIvan Butygin //===- VectorLinearize.cpp - vector linearization transforms --------------===//
235ef3994SIvan Butygin //
335ef3994SIvan Butygin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
435ef3994SIvan Butygin // See https://llvm.org/LICENSE.txt for license information.
535ef3994SIvan Butygin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
635ef3994SIvan Butygin //
735ef3994SIvan Butygin //===----------------------------------------------------------------------===//
835ef3994SIvan Butygin //
935ef3994SIvan Butygin // This file implements patterns and pass for linearizing ND vectors into 1D.
1035ef3994SIvan Butygin //
1135ef3994SIvan Butygin //===----------------------------------------------------------------------===//
1235ef3994SIvan Butygin 
1335ef3994SIvan Butygin #include "mlir/Dialect/Arith/IR/Arith.h"
1435ef3994SIvan Butygin #include "mlir/Dialect/Vector/IR/VectorOps.h"
1535ef3994SIvan Butygin #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16c577f91dSCharitha Saumya #include "mlir/IR/Attributes.h"
17c577f91dSCharitha Saumya #include "mlir/IR/BuiltinAttributes.h"
18c577f91dSCharitha Saumya #include "mlir/IR/Operation.h"
1935ef3994SIvan Butygin #include "mlir/IR/PatternMatch.h"
2035ef3994SIvan Butygin #include "mlir/IR/TypeUtilities.h"
2135ef3994SIvan Butygin #include "mlir/Transforms/DialectConversion.h"
22c577f91dSCharitha Saumya #include "llvm/ADT/ArrayRef.h"
23c577f91dSCharitha Saumya #include <cstdint>
24c577f91dSCharitha Saumya #include <numeric>
2535ef3994SIvan Butygin 
2635ef3994SIvan Butygin using namespace mlir;
2735ef3994SIvan Butygin 
286f5c4f2eSBalaji V. Iyer static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
296f5c4f2eSBalaji V. Iyer   auto resultTypes = op->getResultTypes();
306f5c4f2eSBalaji V. Iyer   for (auto resType : resultTypes) {
315f1f9cfaSBalaji V. Iyer     VectorType vecType = dyn_cast<VectorType>(resType);
326f5c4f2eSBalaji V. Iyer     // Reject index since getElementTypeBitWidth will abort for Index types.
335f1f9cfaSBalaji V. Iyer     if (!vecType || vecType.getElementType().isIndex())
346f5c4f2eSBalaji V. Iyer       return false;
35ef5a7109SHan-Chung Wang     // There are no dimension to fold if it is a 0-D vector.
36ef5a7109SHan-Chung Wang     if (vecType.getRank() == 0)
37ef5a7109SHan-Chung Wang       return false;
386f5c4f2eSBalaji V. Iyer     unsigned trailingVecDimBitWidth =
396f5c4f2eSBalaji V. Iyer         vecType.getShape().back() * vecType.getElementTypeBitWidth();
406f5c4f2eSBalaji V. Iyer     if (trailingVecDimBitWidth >= targetBitWidth)
416f5c4f2eSBalaji V. Iyer       return false;
426f5c4f2eSBalaji V. Iyer   }
436f5c4f2eSBalaji V. Iyer   return true;
446f5c4f2eSBalaji V. Iyer }
456f5c4f2eSBalaji V. Iyer 
4601fbc565SArtem Kroviakov static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
4701fbc565SArtem Kroviakov   VectorType vecType = dyn_cast<VectorType>(t);
4801fbc565SArtem Kroviakov   // Reject index since getElementTypeBitWidth will abort for Index types.
4901fbc565SArtem Kroviakov   if (!vecType || vecType.getElementType().isIndex())
5001fbc565SArtem Kroviakov     return false;
5101fbc565SArtem Kroviakov   // There are no dimension to fold if it is a 0-D vector.
5201fbc565SArtem Kroviakov   if (vecType.getRank() == 0)
5301fbc565SArtem Kroviakov     return false;
5401fbc565SArtem Kroviakov   unsigned trailingVecDimBitWidth =
5501fbc565SArtem Kroviakov       vecType.getShape().back() * vecType.getElementTypeBitWidth();
5601fbc565SArtem Kroviakov   return trailingVecDimBitWidth <= targetBitWidth;
5701fbc565SArtem Kroviakov }
5801fbc565SArtem Kroviakov 
5935ef3994SIvan Butygin namespace {
6035ef3994SIvan Butygin struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
6135ef3994SIvan Butygin   using OpConversionPattern::OpConversionPattern;
626f5c4f2eSBalaji V. Iyer   LinearizeConstant(
636f5c4f2eSBalaji V. Iyer       const TypeConverter &typeConverter, MLIRContext *context,
646f5c4f2eSBalaji V. Iyer       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
656f5c4f2eSBalaji V. Iyer       PatternBenefit benefit = 1)
666f5c4f2eSBalaji V. Iyer       : OpConversionPattern(typeConverter, context, benefit),
676f5c4f2eSBalaji V. Iyer         targetVectorBitWidth(targetVectBitWidth) {}
6835ef3994SIvan Butygin   LogicalResult
6935ef3994SIvan Butygin   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
7035ef3994SIvan Butygin                   ConversionPatternRewriter &rewriter) const override {
7135ef3994SIvan Butygin     Location loc = constOp.getLoc();
7235ef3994SIvan Butygin     auto resType =
7335ef3994SIvan Butygin         getTypeConverter()->convertType<VectorType>(constOp.getType());
74d3aa92edSAndrzej Warzyński 
75*bd5d361cSChao Chen     if (!resType)
76*bd5d361cSChao Chen       return rewriter.notifyMatchFailure(loc, "can't convert return type");
77*bd5d361cSChao Chen 
78d3aa92edSAndrzej Warzyński     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
79d3aa92edSAndrzej Warzyński       return rewriter.notifyMatchFailure(
80d3aa92edSAndrzej Warzyński           loc,
81d3aa92edSAndrzej Warzyński           "Cannot linearize a constant scalable vector that's not a splat");
82d3aa92edSAndrzej Warzyński 
836f5c4f2eSBalaji V. Iyer     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
846f5c4f2eSBalaji V. Iyer       return rewriter.notifyMatchFailure(
856f5c4f2eSBalaji V. Iyer           loc, "Can't flatten since targetBitWidth <= OpSize");
8635ef3994SIvan Butygin     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
8735ef3994SIvan Butygin     if (!dstElementsAttr)
8835ef3994SIvan Butygin       return rewriter.notifyMatchFailure(loc, "unsupported attr type");
8935ef3994SIvan Butygin 
9035ef3994SIvan Butygin     dstElementsAttr = dstElementsAttr.reshape(resType);
9135ef3994SIvan Butygin     rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
9235ef3994SIvan Butygin                                                    dstElementsAttr);
9335ef3994SIvan Butygin     return success();
9435ef3994SIvan Butygin   }
956f5c4f2eSBalaji V. Iyer 
966f5c4f2eSBalaji V. Iyer private:
976f5c4f2eSBalaji V. Iyer   unsigned targetVectorBitWidth;
9835ef3994SIvan Butygin };
9935ef3994SIvan Butygin 
10035ef3994SIvan Butygin struct LinearizeVectorizable final
10135ef3994SIvan Butygin     : OpTraitConversionPattern<OpTrait::Vectorizable> {
10235ef3994SIvan Butygin   using OpTraitConversionPattern::OpTraitConversionPattern;
10335ef3994SIvan Butygin 
1046f5c4f2eSBalaji V. Iyer public:
1056f5c4f2eSBalaji V. Iyer   LinearizeVectorizable(
1066f5c4f2eSBalaji V. Iyer       const TypeConverter &typeConverter, MLIRContext *context,
1076f5c4f2eSBalaji V. Iyer       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
1086f5c4f2eSBalaji V. Iyer       PatternBenefit benefit = 1)
1096f5c4f2eSBalaji V. Iyer       : OpTraitConversionPattern(typeConverter, context, benefit),
1106f5c4f2eSBalaji V. Iyer         targetVectorBitWidth(targetVectBitWidth) {}
11135ef3994SIvan Butygin   LogicalResult
11235ef3994SIvan Butygin   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
11335ef3994SIvan Butygin                   ConversionPatternRewriter &rewriter) const override {
1146f5c4f2eSBalaji V. Iyer     if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
1156f5c4f2eSBalaji V. Iyer       return rewriter.notifyMatchFailure(
1166f5c4f2eSBalaji V. Iyer           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
11735ef3994SIvan Butygin     FailureOr<Operation *> newOp =
11835ef3994SIvan Butygin         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
11935ef3994SIvan Butygin     if (failed(newOp))
12035ef3994SIvan Butygin       return failure();
12135ef3994SIvan Butygin 
12235ef3994SIvan Butygin     rewriter.replaceOp(op, (*newOp)->getResults());
12335ef3994SIvan Butygin     return success();
12435ef3994SIvan Butygin   }
1256f5c4f2eSBalaji V. Iyer 
1266f5c4f2eSBalaji V. Iyer private:
1276f5c4f2eSBalaji V. Iyer   unsigned targetVectorBitWidth;
12835ef3994SIvan Butygin };
129c577f91dSCharitha Saumya 
130c577f91dSCharitha Saumya /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
131c577f91dSCharitha Saumya /// on a linearized vector.
132c577f91dSCharitha Saumya /// Following,
133c577f91dSCharitha Saumya ///   vector.extract_strided_slice %source
134c577f91dSCharitha Saumya ///         { offsets = [..], strides = [..], sizes = [..] }
135c577f91dSCharitha Saumya /// is converted to :
136c577f91dSCharitha Saumya ///   %source_1d = vector.shape_cast %source
137c577f91dSCharitha Saumya ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
138c577f91dSCharitha Saumya ///   %out_nd = vector.shape_cast %out_1d
139c577f91dSCharitha Saumya /// `shuffle_indices_1d` is computed using the offsets and sizes of the
140c577f91dSCharitha Saumya /// extraction.
141c577f91dSCharitha Saumya struct LinearizeVectorExtractStridedSlice final
142c577f91dSCharitha Saumya     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
143c577f91dSCharitha Saumya   using OpConversionPattern::OpConversionPattern;
144c577f91dSCharitha Saumya   LinearizeVectorExtractStridedSlice(
145c577f91dSCharitha Saumya       const TypeConverter &typeConverter, MLIRContext *context,
146c577f91dSCharitha Saumya       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
147c577f91dSCharitha Saumya       PatternBenefit benefit = 1)
148c577f91dSCharitha Saumya       : OpConversionPattern(typeConverter, context, benefit),
149c577f91dSCharitha Saumya         targetVectorBitWidth(targetVectBitWidth) {}
150c577f91dSCharitha Saumya 
151c577f91dSCharitha Saumya   LogicalResult
152c577f91dSCharitha Saumya   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
153c577f91dSCharitha Saumya                   ConversionPatternRewriter &rewriter) const override {
15474a105adSArtem Kroviakov     VectorType dstType =
15574a105adSArtem Kroviakov         getTypeConverter()->convertType<VectorType>(extractOp.getType());
15674a105adSArtem Kroviakov     assert(dstType && "vector type destination expected.");
15774a105adSArtem Kroviakov     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
15874a105adSArtem Kroviakov       return rewriter.notifyMatchFailure(extractOp,
159c577f91dSCharitha Saumya                                          "scalable vectors are not supported.");
160c577f91dSCharitha Saumya     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
161c577f91dSCharitha Saumya       return rewriter.notifyMatchFailure(
162c577f91dSCharitha Saumya           extractOp, "Can't flatten since targetBitWidth <= OpSize");
163c577f91dSCharitha Saumya 
164c577f91dSCharitha Saumya     ArrayAttr offsets = extractOp.getOffsets();
165c577f91dSCharitha Saumya     ArrayAttr sizes = extractOp.getSizes();
166c577f91dSCharitha Saumya     ArrayAttr strides = extractOp.getStrides();
167c577f91dSCharitha Saumya     if (!isConstantIntValue(strides[0], 1))
168c577f91dSCharitha Saumya       return rewriter.notifyMatchFailure(
169c577f91dSCharitha Saumya           extractOp, "Strided slice with stride != 1 is not supported.");
170c577f91dSCharitha Saumya     Value srcVector = adaptor.getVector();
171c577f91dSCharitha Saumya     // If kD offsets are specified for nD source vector (n > k), the granularity
172c577f91dSCharitha Saumya     // of the extraction is greater than 1. In this case last (n-k) dimensions
173c577f91dSCharitha Saumya     // form the extraction granularity.
174c577f91dSCharitha Saumya     // Example :
175c577f91dSCharitha Saumya     //  vector.extract_strided_slice %src {
176c577f91dSCharitha Saumya     //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
177c577f91dSCharitha Saumya     //      vector<4x8x8xf32> to vector<2x2x8xf32>
178c577f91dSCharitha Saumya     // Here, extraction granularity is 8.
179c577f91dSCharitha Saumya     int64_t extractGranularitySize = 1;
180c577f91dSCharitha Saumya     int64_t nD = extractOp.getSourceVectorType().getRank();
181c577f91dSCharitha Saumya     int64_t kD = (int64_t)offsets.size();
182c577f91dSCharitha Saumya     int64_t k = kD;
183c577f91dSCharitha Saumya     while (k < nD) {
184c577f91dSCharitha Saumya       extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
185c577f91dSCharitha Saumya       ++k;
186c577f91dSCharitha Saumya     }
187c577f91dSCharitha Saumya     // Get total number of extracted slices.
188c577f91dSCharitha Saumya     int64_t nExtractedSlices = 1;
189c577f91dSCharitha Saumya     for (Attribute size : sizes) {
190fac349a1SChristian Sigg       nExtractedSlices *= cast<IntegerAttr>(size).getInt();
191c577f91dSCharitha Saumya     }
192c577f91dSCharitha Saumya     // Compute the strides of the source vector considering first k dimensions.
193c577f91dSCharitha Saumya     llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
194c577f91dSCharitha Saumya     for (int i = kD - 2; i >= 0; --i) {
195c577f91dSCharitha Saumya       sourceStrides[i] = sourceStrides[i + 1] *
196c577f91dSCharitha Saumya                          extractOp.getSourceVectorType().getShape()[i + 1];
197c577f91dSCharitha Saumya     }
198c577f91dSCharitha Saumya     // Final shuffle indices has nExtractedSlices * extractGranularitySize
199c577f91dSCharitha Saumya     // elements.
200c577f91dSCharitha Saumya     llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
201c577f91dSCharitha Saumya                                           extractGranularitySize);
202c577f91dSCharitha Saumya     // Compute the strides of the extracted kD vector.
203c577f91dSCharitha Saumya     llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
204c577f91dSCharitha Saumya     // Compute extractedStrides.
205c577f91dSCharitha Saumya     for (int i = kD - 2; i >= 0; --i) {
206c577f91dSCharitha Saumya       extractedStrides[i] =
207fac349a1SChristian Sigg           extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
208c577f91dSCharitha Saumya     }
209c577f91dSCharitha Saumya     // Iterate over all extracted slices from 0 to nExtractedSlices - 1
210c577f91dSCharitha Saumya     // and compute the multi-dimensional index and the corresponding linearized
211c577f91dSCharitha Saumya     // index within the source vector.
212c577f91dSCharitha Saumya     for (int64_t i = 0; i < nExtractedSlices; ++i) {
213c577f91dSCharitha Saumya       int64_t index = i;
214c577f91dSCharitha Saumya       // Compute the corresponding multi-dimensional index.
215c577f91dSCharitha Saumya       llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
216c577f91dSCharitha Saumya       for (int64_t j = 0; j < kD; ++j) {
217c577f91dSCharitha Saumya         multiDimIndex[j] = (index / extractedStrides[j]);
218c577f91dSCharitha Saumya         index -= multiDimIndex[j] * extractedStrides[j];
219c577f91dSCharitha Saumya       }
220c577f91dSCharitha Saumya       // Compute the corresponding linearized index in the source vector
221c577f91dSCharitha Saumya       // i.e. shift the multiDimIndex by the offsets.
222c577f91dSCharitha Saumya       int64_t linearizedIndex = 0;
223c577f91dSCharitha Saumya       for (int64_t j = 0; j < kD; ++j) {
224c577f91dSCharitha Saumya         linearizedIndex +=
225fac349a1SChristian Sigg             (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
226c577f91dSCharitha Saumya             sourceStrides[j];
227c577f91dSCharitha Saumya       }
228c577f91dSCharitha Saumya       // Fill the indices array form linearizedIndex to linearizedIndex +
229c577f91dSCharitha Saumya       // extractGranularitySize.
230c577f91dSCharitha Saumya       for (int64_t j = 0; j < extractGranularitySize; ++j) {
231c577f91dSCharitha Saumya         indices[i * extractGranularitySize + j] = linearizedIndex + j;
232c577f91dSCharitha Saumya       }
233c577f91dSCharitha Saumya     }
234c577f91dSCharitha Saumya     // Perform a shuffle to extract the kD vector.
235c577f91dSCharitha Saumya     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
236b4444dcaSBenjamin Maxwell         extractOp, dstType, srcVector, srcVector, indices);
237c577f91dSCharitha Saumya     return success();
238c577f91dSCharitha Saumya   }
239c577f91dSCharitha Saumya 
240c577f91dSCharitha Saumya private:
241c577f91dSCharitha Saumya   unsigned targetVectorBitWidth;
242c577f91dSCharitha Saumya };
243c577f91dSCharitha Saumya 
244c577f91dSCharitha Saumya /// This pattern converts the ShuffleOp that works on nD (n > 1)
245c577f91dSCharitha Saumya /// vectors to a ShuffleOp that works on linearized vectors.
246c577f91dSCharitha Saumya /// Following,
247c577f91dSCharitha Saumya ///   vector.shuffle %v1, %v2 [ shuffle_indices ]
248c577f91dSCharitha Saumya /// is converted to :
249c577f91dSCharitha Saumya ///   %v1_1d = vector.shape_cast %v1
250c577f91dSCharitha Saumya ///   %v2_1d = vector.shape_cast %v2
251c577f91dSCharitha Saumya ///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
252c577f91dSCharitha Saumya ///   %out_nd = vector.shape_cast %out_1d
253c577f91dSCharitha Saumya // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
254c577f91dSCharitha Saumya /// of the original shuffle operation.
255c577f91dSCharitha Saumya struct LinearizeVectorShuffle final
256c577f91dSCharitha Saumya     : public OpConversionPattern<vector::ShuffleOp> {
257c577f91dSCharitha Saumya   using OpConversionPattern::OpConversionPattern;
258c577f91dSCharitha Saumya   LinearizeVectorShuffle(
259c577f91dSCharitha Saumya       const TypeConverter &typeConverter, MLIRContext *context,
260c577f91dSCharitha Saumya       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
261c577f91dSCharitha Saumya       PatternBenefit benefit = 1)
262c577f91dSCharitha Saumya       : OpConversionPattern(typeConverter, context, benefit),
263c577f91dSCharitha Saumya         targetVectorBitWidth(targetVectBitWidth) {}
264c577f91dSCharitha Saumya 
265c577f91dSCharitha Saumya   LogicalResult
266c577f91dSCharitha Saumya   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
267c577f91dSCharitha Saumya                   ConversionPatternRewriter &rewriter) const override {
26874a105adSArtem Kroviakov     VectorType dstType =
26974a105adSArtem Kroviakov         getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
27074a105adSArtem Kroviakov     assert(dstType && "vector type destination expected.");
27174a105adSArtem Kroviakov     // The assert is used because vector.shuffle does not support scalable
27274a105adSArtem Kroviakov     // vectors.
273c577f91dSCharitha Saumya     assert(!(shuffleOp.getV1VectorType().isScalable() ||
274c577f91dSCharitha Saumya              shuffleOp.getV2VectorType().isScalable() ||
27574a105adSArtem Kroviakov              dstType.isScalable()) &&
276c577f91dSCharitha Saumya            "scalable vectors are not supported.");
277c577f91dSCharitha Saumya     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
278c577f91dSCharitha Saumya       return rewriter.notifyMatchFailure(
279c577f91dSCharitha Saumya           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
280c577f91dSCharitha Saumya 
281c577f91dSCharitha Saumya     Value vec1 = adaptor.getV1();
282c577f91dSCharitha Saumya     Value vec2 = adaptor.getV2();
283c577f91dSCharitha Saumya     int shuffleSliceLen = 1;
284c577f91dSCharitha Saumya     int rank = shuffleOp.getV1().getType().getRank();
285c577f91dSCharitha Saumya 
286c577f91dSCharitha Saumya     // If rank > 1, we need to do the shuffle in the granularity of slices
287c577f91dSCharitha Saumya     // instead of scalars. Size of the slice is equal to the rank-1 innermost
288c577f91dSCharitha Saumya     // dims. Mask of the shuffle op specifies which slice to take from the
289c577f91dSCharitha Saumya     // outermost dim.
290c577f91dSCharitha Saumya     if (rank > 1) {
291c577f91dSCharitha Saumya       llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
292c577f91dSCharitha Saumya       for (unsigned i = 1; i < shape.size(); ++i) {
293c577f91dSCharitha Saumya         shuffleSliceLen *= shape[i];
294c577f91dSCharitha Saumya       }
295c577f91dSCharitha Saumya     }
296c577f91dSCharitha Saumya 
297c577f91dSCharitha Saumya     // For each value in the mask, we generate the indices of the source vectors
298c577f91dSCharitha Saumya     // that needs to be shuffled to the destination vector. If shuffleSliceLen >
299c577f91dSCharitha Saumya     // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
300c577f91dSCharitha Saumya     // elements) instead of scalars.
301b4444dcaSBenjamin Maxwell     ArrayRef<int64_t> mask = shuffleOp.getMask();
302c577f91dSCharitha Saumya     int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
303c577f91dSCharitha Saumya     llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
304b4444dcaSBenjamin Maxwell     for (auto [i, value] : llvm::enumerate(mask)) {
305c577f91dSCharitha Saumya       std::iota(indices.begin() + shuffleSliceLen * i,
306c577f91dSCharitha Saumya                 indices.begin() + shuffleSliceLen * (i + 1),
307b4444dcaSBenjamin Maxwell                 shuffleSliceLen * value);
308c577f91dSCharitha Saumya     }
309c577f91dSCharitha Saumya 
310b4444dcaSBenjamin Maxwell     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
311b4444dcaSBenjamin Maxwell                                                    vec2, indices);
312c577f91dSCharitha Saumya     return success();
313c577f91dSCharitha Saumya   }
314c577f91dSCharitha Saumya 
315c577f91dSCharitha Saumya private:
316c577f91dSCharitha Saumya   unsigned targetVectorBitWidth;
317c577f91dSCharitha Saumya };
318c577f91dSCharitha Saumya 
319c577f91dSCharitha Saumya /// This pattern converts the ExtractOp to a ShuffleOp that works on a
320c577f91dSCharitha Saumya /// linearized vector.
321c577f91dSCharitha Saumya /// Following,
322c577f91dSCharitha Saumya ///   vector.extract %source [ position ]
323c577f91dSCharitha Saumya /// is converted to :
324c577f91dSCharitha Saumya ///   %source_1d = vector.shape_cast %source
325c577f91dSCharitha Saumya ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
326c577f91dSCharitha Saumya ///   %out_nd = vector.shape_cast %out_1d
327c577f91dSCharitha Saumya /// `shuffle_indices_1d` is computed using the position of the original extract.
328c577f91dSCharitha Saumya struct LinearizeVectorExtract final
329c577f91dSCharitha Saumya     : public OpConversionPattern<vector::ExtractOp> {
330c577f91dSCharitha Saumya   using OpConversionPattern::OpConversionPattern;
331c577f91dSCharitha Saumya   LinearizeVectorExtract(
332c577f91dSCharitha Saumya       const TypeConverter &typeConverter, MLIRContext *context,
333c577f91dSCharitha Saumya       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
334c577f91dSCharitha Saumya       PatternBenefit benefit = 1)
335c577f91dSCharitha Saumya       : OpConversionPattern(typeConverter, context, benefit),
336c577f91dSCharitha Saumya         targetVectorBitWidth(targetVectBitWidth) {}
337c577f91dSCharitha Saumya   LogicalResult
338c577f91dSCharitha Saumya   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
339c577f91dSCharitha Saumya                   ConversionPatternRewriter &rewriter) const override {
340c577f91dSCharitha Saumya     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
34150febdebSLongsheng Mou     if (!dstTy)
34250febdebSLongsheng Mou       return rewriter.notifyMatchFailure(extractOp,
34350febdebSLongsheng Mou                                          "expected n-D vector type.");
34450febdebSLongsheng Mou 
34574a105adSArtem Kroviakov     if (extractOp.getVector().getType().isScalable() ||
34674a105adSArtem Kroviakov         cast<VectorType>(dstTy).isScalable())
34774a105adSArtem Kroviakov       return rewriter.notifyMatchFailure(extractOp,
348c577f91dSCharitha Saumya                                          "scalable vectors are not supported.");
349c577f91dSCharitha Saumya     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
350c577f91dSCharitha Saumya       return rewriter.notifyMatchFailure(
351c577f91dSCharitha Saumya           extractOp, "Can't flatten since targetBitWidth <= OpSize");
352c577f91dSCharitha Saumya 
353c577f91dSCharitha Saumya     // Dynamic position is not supported.
354c577f91dSCharitha Saumya     if (extractOp.hasDynamicPosition())
355c577f91dSCharitha Saumya       return rewriter.notifyMatchFailure(extractOp,
356c577f91dSCharitha Saumya                                          "dynamic position is not supported.");
357c577f91dSCharitha Saumya 
358c577f91dSCharitha Saumya     llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
359c577f91dSCharitha Saumya     int64_t size = extractOp.getVector().getType().getNumElements();
360c577f91dSCharitha Saumya 
361c577f91dSCharitha Saumya     // Compute linearized offset.
362c577f91dSCharitha Saumya     int64_t linearizedOffset = 0;
363c577f91dSCharitha Saumya     llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
364c577f91dSCharitha Saumya     for (auto [i, off] : llvm::enumerate(offsets)) {
365c577f91dSCharitha Saumya       size /= shape[i];
366c577f91dSCharitha Saumya       linearizedOffset += offsets[i] * size;
367c577f91dSCharitha Saumya     }
368c577f91dSCharitha Saumya 
369c577f91dSCharitha Saumya     llvm::SmallVector<int64_t, 2> indices(size);
370c577f91dSCharitha Saumya     std::iota(indices.begin(), indices.end(), linearizedOffset);
371c577f91dSCharitha Saumya     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
372b4444dcaSBenjamin Maxwell         extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
373c577f91dSCharitha Saumya 
374c577f91dSCharitha Saumya     return success();
375c577f91dSCharitha Saumya   }
376c577f91dSCharitha Saumya 
377c577f91dSCharitha Saumya private:
378c577f91dSCharitha Saumya   unsigned targetVectorBitWidth;
379c577f91dSCharitha Saumya };
38001fbc565SArtem Kroviakov 
38101fbc565SArtem Kroviakov /// This pattern converts the InsertOp to a ShuffleOp that works on a
38201fbc565SArtem Kroviakov /// linearized vector.
38301fbc565SArtem Kroviakov /// Following,
38401fbc565SArtem Kroviakov ///   vector.insert %source %destination [ position ]
38501fbc565SArtem Kroviakov /// is converted to :
38601fbc565SArtem Kroviakov ///   %source_1d = vector.shape_cast %source
38701fbc565SArtem Kroviakov ///   %destination_1d = vector.shape_cast %destination
38801fbc565SArtem Kroviakov ///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
38901fbc565SArtem Kroviakov ///   ] %out_nd = vector.shape_cast %out_1d
39001fbc565SArtem Kroviakov /// `shuffle_indices_1d` is computed using the position of the original insert.
39101fbc565SArtem Kroviakov struct LinearizeVectorInsert final
39201fbc565SArtem Kroviakov     : public OpConversionPattern<vector::InsertOp> {
39301fbc565SArtem Kroviakov   using OpConversionPattern::OpConversionPattern;
39401fbc565SArtem Kroviakov   LinearizeVectorInsert(
39501fbc565SArtem Kroviakov       const TypeConverter &typeConverter, MLIRContext *context,
39601fbc565SArtem Kroviakov       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
39701fbc565SArtem Kroviakov       PatternBenefit benefit = 1)
39801fbc565SArtem Kroviakov       : OpConversionPattern(typeConverter, context, benefit),
39901fbc565SArtem Kroviakov         targetVectorBitWidth(targetVectBitWidth) {}
40001fbc565SArtem Kroviakov   LogicalResult
40101fbc565SArtem Kroviakov   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
40201fbc565SArtem Kroviakov                   ConversionPatternRewriter &rewriter) const override {
40374a105adSArtem Kroviakov     VectorType dstTy = getTypeConverter()->convertType<VectorType>(
40474a105adSArtem Kroviakov         insertOp.getDestVectorType());
40574a105adSArtem Kroviakov     assert(dstTy && "vector type destination expected.");
40674a105adSArtem Kroviakov     if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
40774a105adSArtem Kroviakov       return rewriter.notifyMatchFailure(insertOp,
40801fbc565SArtem Kroviakov                                          "scalable vectors are not supported.");
40901fbc565SArtem Kroviakov 
41001fbc565SArtem Kroviakov     if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
41101fbc565SArtem Kroviakov                                          targetVectorBitWidth))
41201fbc565SArtem Kroviakov       return rewriter.notifyMatchFailure(
41301fbc565SArtem Kroviakov           insertOp, "Can't flatten since targetBitWidth < OpSize");
41401fbc565SArtem Kroviakov 
41501fbc565SArtem Kroviakov     // dynamic position is not supported
41601fbc565SArtem Kroviakov     if (insertOp.hasDynamicPosition())
41701fbc565SArtem Kroviakov       return rewriter.notifyMatchFailure(insertOp,
41801fbc565SArtem Kroviakov                                          "dynamic position is not supported.");
41901fbc565SArtem Kroviakov     auto srcTy = insertOp.getSourceType();
42001fbc565SArtem Kroviakov     auto srcAsVec = dyn_cast<VectorType>(srcTy);
42101fbc565SArtem Kroviakov     uint64_t srcSize = 0;
42201fbc565SArtem Kroviakov     if (srcAsVec) {
42301fbc565SArtem Kroviakov       srcSize = srcAsVec.getNumElements();
42401fbc565SArtem Kroviakov     } else {
42501fbc565SArtem Kroviakov       return rewriter.notifyMatchFailure(insertOp,
42601fbc565SArtem Kroviakov                                          "scalars are not supported.");
42701fbc565SArtem Kroviakov     }
42801fbc565SArtem Kroviakov 
42901fbc565SArtem Kroviakov     auto dstShape = insertOp.getDestVectorType().getShape();
43001fbc565SArtem Kroviakov     const auto dstSize = insertOp.getDestVectorType().getNumElements();
43101fbc565SArtem Kroviakov     auto dstSizeForOffsets = dstSize;
43201fbc565SArtem Kroviakov 
43301fbc565SArtem Kroviakov     // compute linearized offset
43401fbc565SArtem Kroviakov     int64_t linearizedOffset = 0;
43501fbc565SArtem Kroviakov     auto offsetsNd = insertOp.getStaticPosition();
43601fbc565SArtem Kroviakov     for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
43701fbc565SArtem Kroviakov       dstSizeForOffsets /= dstShape[dim];
43801fbc565SArtem Kroviakov       linearizedOffset += offset * dstSizeForOffsets;
43901fbc565SArtem Kroviakov     }
44001fbc565SArtem Kroviakov 
44101fbc565SArtem Kroviakov     llvm::SmallVector<int64_t, 2> indices(dstSize);
44201fbc565SArtem Kroviakov     auto origValsUntil = indices.begin();
44301fbc565SArtem Kroviakov     std::advance(origValsUntil, linearizedOffset);
44401fbc565SArtem Kroviakov     std::iota(indices.begin(), origValsUntil,
44501fbc565SArtem Kroviakov               0); // original values that remain [0, offset)
44601fbc565SArtem Kroviakov     auto newValsUntil = origValsUntil;
44701fbc565SArtem Kroviakov     std::advance(newValsUntil, srcSize);
44801fbc565SArtem Kroviakov     std::iota(origValsUntil, newValsUntil,
44901fbc565SArtem Kroviakov               dstSize); // new values [offset, offset+srcNumElements)
45001fbc565SArtem Kroviakov     std::iota(newValsUntil, indices.end(),
45101fbc565SArtem Kroviakov               linearizedOffset + srcSize); // the rest of original values
45201fbc565SArtem Kroviakov                                            // [offset+srcNumElements, end)
45301fbc565SArtem Kroviakov 
45401fbc565SArtem Kroviakov     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
455b4444dcaSBenjamin Maxwell         insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
45601fbc565SArtem Kroviakov 
45701fbc565SArtem Kroviakov     return success();
45801fbc565SArtem Kroviakov   }
45901fbc565SArtem Kroviakov 
46001fbc565SArtem Kroviakov private:
46101fbc565SArtem Kroviakov   unsigned targetVectorBitWidth;
46201fbc565SArtem Kroviakov };
463*bd5d361cSChao Chen 
464*bd5d361cSChao Chen /// This pattern converts the BitCastOp that works on nD (n > 1)
465*bd5d361cSChao Chen /// vectors to a BitCastOp that works on linearized vectors.
466*bd5d361cSChao Chen /// Following,
467*bd5d361cSChao Chen ///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
468*bd5d361cSChao Chen /// is converted to :
469*bd5d361cSChao Chen ///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
470*bd5d361cSChao Chen ///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
471*bd5d361cSChao Chen ///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
472*bd5d361cSChao Chen struct LinearizeVectorBitCast final
473*bd5d361cSChao Chen     : public OpConversionPattern<vector::BitCastOp> {
474*bd5d361cSChao Chen   using OpConversionPattern::OpConversionPattern;
475*bd5d361cSChao Chen   LinearizeVectorBitCast(
476*bd5d361cSChao Chen       const TypeConverter &typeConverter, MLIRContext *context,
477*bd5d361cSChao Chen       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
478*bd5d361cSChao Chen       PatternBenefit benefit = 1)
479*bd5d361cSChao Chen       : OpConversionPattern(typeConverter, context, benefit),
480*bd5d361cSChao Chen         targetVectorBitWidth(targetVectBitWidth) {}
481*bd5d361cSChao Chen   LogicalResult
482*bd5d361cSChao Chen   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
483*bd5d361cSChao Chen                   ConversionPatternRewriter &rewriter) const override {
484*bd5d361cSChao Chen     Location loc = castOp.getLoc();
485*bd5d361cSChao Chen     auto resType = getTypeConverter()->convertType(castOp.getType());
486*bd5d361cSChao Chen     if (!resType)
487*bd5d361cSChao Chen       return rewriter.notifyMatchFailure(loc, "can't convert return type.");
488*bd5d361cSChao Chen 
489*bd5d361cSChao Chen     if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
490*bd5d361cSChao Chen       return rewriter.notifyMatchFailure(
491*bd5d361cSChao Chen           loc, "Can't flatten since targetBitWidth <= OpSize");
492*bd5d361cSChao Chen 
493*bd5d361cSChao Chen     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
494*bd5d361cSChao Chen                                                    adaptor.getSource());
495*bd5d361cSChao Chen     return mlir::success();
496*bd5d361cSChao Chen   }
497*bd5d361cSChao Chen 
498*bd5d361cSChao Chen private:
499*bd5d361cSChao Chen   unsigned targetVectorBitWidth;
500*bd5d361cSChao Chen };
501*bd5d361cSChao Chen 
50235ef3994SIvan Butygin } // namespace
50335ef3994SIvan Butygin 
50435ef3994SIvan Butygin void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
50535ef3994SIvan Butygin     TypeConverter &typeConverter, RewritePatternSet &patterns,
5066f5c4f2eSBalaji V. Iyer     ConversionTarget &target, unsigned targetBitWidth) {
5076f5c4f2eSBalaji V. Iyer 
50835ef3994SIvan Butygin   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
509d3aa92edSAndrzej Warzyński     if (!isLinearizableVector(type))
51035ef3994SIvan Butygin       return type;
51135ef3994SIvan Butygin 
512d3aa92edSAndrzej Warzyński     return VectorType::get(type.getNumElements(), type.getElementType(),
513d3aa92edSAndrzej Warzyński                            type.isScalable());
51435ef3994SIvan Butygin   });
51535ef3994SIvan Butygin 
51635ef3994SIvan Butygin   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
51735ef3994SIvan Butygin                             Location loc) -> Value {
51835ef3994SIvan Butygin     if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
51935ef3994SIvan Butygin         !isa<VectorType>(type))
52035ef3994SIvan Butygin       return nullptr;
52135ef3994SIvan Butygin 
52235ef3994SIvan Butygin     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
52335ef3994SIvan Butygin   };
52435ef3994SIvan Butygin   typeConverter.addSourceMaterialization(materializeCast);
52535ef3994SIvan Butygin   typeConverter.addTargetMaterialization(materializeCast);
52635ef3994SIvan Butygin   target.markUnknownOpDynamicallyLegal(
5276f5c4f2eSBalaji V. Iyer       [=](Operation *op) -> std::optional<bool> {
528*bd5d361cSChao Chen         if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
5296f5c4f2eSBalaji V. Iyer              op->hasTrait<OpTrait::Vectorizable>())) {
5306f5c4f2eSBalaji V. Iyer           return (isLessThanTargetBitWidth(op, targetBitWidth)
5316f5c4f2eSBalaji V. Iyer                       ? typeConverter.isLegal(op)
5326f5c4f2eSBalaji V. Iyer                       : true);
5336f5c4f2eSBalaji V. Iyer         }
53435ef3994SIvan Butygin         return std::nullopt;
53535ef3994SIvan Butygin       });
53635ef3994SIvan Butygin 
537*bd5d361cSChao Chen   patterns
538*bd5d361cSChao Chen       .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
5396f5c4f2eSBalaji V. Iyer           typeConverter, patterns.getContext(), targetBitWidth);
54035ef3994SIvan Butygin }
541c577f91dSCharitha Saumya 
542c577f91dSCharitha Saumya void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
543206fad0eSMatthias Springer     const TypeConverter &typeConverter, RewritePatternSet &patterns,
544c577f91dSCharitha Saumya     ConversionTarget &target, unsigned int targetBitWidth) {
545c577f91dSCharitha Saumya   target.addDynamicallyLegalOp<vector::ShuffleOp>(
546c577f91dSCharitha Saumya       [=](vector::ShuffleOp shuffleOp) -> bool {
547c577f91dSCharitha Saumya         return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
548c577f91dSCharitha Saumya                    ? (typeConverter.isLegal(shuffleOp) &&
549fac349a1SChristian Sigg                       cast<mlir::VectorType>(shuffleOp.getResult().getType())
550c577f91dSCharitha Saumya                               .getRank() == 1)
551c577f91dSCharitha Saumya                    : true;
552c577f91dSCharitha Saumya       });
553c577f91dSCharitha Saumya   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
55401fbc565SArtem Kroviakov                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
555c577f91dSCharitha Saumya       typeConverter, patterns.getContext(), targetBitWidth);
556c577f91dSCharitha Saumya }
557