135ef3994SIvan Butygin //===- VectorLinearize.cpp - vector linearization transforms --------------===// 235ef3994SIvan Butygin // 335ef3994SIvan Butygin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 435ef3994SIvan Butygin // See https://llvm.org/LICENSE.txt for license information. 535ef3994SIvan Butygin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 635ef3994SIvan Butygin // 735ef3994SIvan Butygin //===----------------------------------------------------------------------===// 835ef3994SIvan Butygin // 935ef3994SIvan Butygin // This file implements patterns and pass for linearizing ND vectors into 1D. 1035ef3994SIvan Butygin // 1135ef3994SIvan Butygin //===----------------------------------------------------------------------===// 1235ef3994SIvan Butygin 1335ef3994SIvan Butygin #include "mlir/Dialect/Arith/IR/Arith.h" 1435ef3994SIvan Butygin #include "mlir/Dialect/Vector/IR/VectorOps.h" 1535ef3994SIvan Butygin #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 16c577f91dSCharitha Saumya #include "mlir/IR/Attributes.h" 17c577f91dSCharitha Saumya #include "mlir/IR/BuiltinAttributes.h" 18c577f91dSCharitha Saumya #include "mlir/IR/Operation.h" 1935ef3994SIvan Butygin #include "mlir/IR/PatternMatch.h" 2035ef3994SIvan Butygin #include "mlir/IR/TypeUtilities.h" 2135ef3994SIvan Butygin #include "mlir/Transforms/DialectConversion.h" 22c577f91dSCharitha Saumya #include "llvm/ADT/ArrayRef.h" 23c577f91dSCharitha Saumya #include <cstdint> 24c577f91dSCharitha Saumya #include <numeric> 2535ef3994SIvan Butygin 2635ef3994SIvan Butygin using namespace mlir; 2735ef3994SIvan Butygin 286f5c4f2eSBalaji V. Iyer static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { 296f5c4f2eSBalaji V. Iyer auto resultTypes = op->getResultTypes(); 306f5c4f2eSBalaji V. Iyer for (auto resType : resultTypes) { 315f1f9cfaSBalaji V. Iyer VectorType vecType = dyn_cast<VectorType>(resType); 326f5c4f2eSBalaji V. Iyer // Reject index since getElementTypeBitWidth will abort for Index types. 335f1f9cfaSBalaji V. Iyer if (!vecType || vecType.getElementType().isIndex()) 346f5c4f2eSBalaji V. Iyer return false; 35ef5a7109SHan-Chung Wang // There are no dimension to fold if it is a 0-D vector. 36ef5a7109SHan-Chung Wang if (vecType.getRank() == 0) 37ef5a7109SHan-Chung Wang return false; 386f5c4f2eSBalaji V. Iyer unsigned trailingVecDimBitWidth = 396f5c4f2eSBalaji V. Iyer vecType.getShape().back() * vecType.getElementTypeBitWidth(); 406f5c4f2eSBalaji V. Iyer if (trailingVecDimBitWidth >= targetBitWidth) 416f5c4f2eSBalaji V. Iyer return false; 426f5c4f2eSBalaji V. Iyer } 436f5c4f2eSBalaji V. Iyer return true; 446f5c4f2eSBalaji V. Iyer } 456f5c4f2eSBalaji V. Iyer 4601fbc565SArtem Kroviakov static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) { 4701fbc565SArtem Kroviakov VectorType vecType = dyn_cast<VectorType>(t); 4801fbc565SArtem Kroviakov // Reject index since getElementTypeBitWidth will abort for Index types. 4901fbc565SArtem Kroviakov if (!vecType || vecType.getElementType().isIndex()) 5001fbc565SArtem Kroviakov return false; 5101fbc565SArtem Kroviakov // There are no dimension to fold if it is a 0-D vector. 5201fbc565SArtem Kroviakov if (vecType.getRank() == 0) 5301fbc565SArtem Kroviakov return false; 5401fbc565SArtem Kroviakov unsigned trailingVecDimBitWidth = 5501fbc565SArtem Kroviakov vecType.getShape().back() * vecType.getElementTypeBitWidth(); 5601fbc565SArtem Kroviakov return trailingVecDimBitWidth <= targetBitWidth; 5701fbc565SArtem Kroviakov } 5801fbc565SArtem Kroviakov 5935ef3994SIvan Butygin namespace { 6035ef3994SIvan Butygin struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> { 6135ef3994SIvan Butygin using OpConversionPattern::OpConversionPattern; 626f5c4f2eSBalaji V. Iyer LinearizeConstant( 636f5c4f2eSBalaji V. Iyer const TypeConverter &typeConverter, MLIRContext *context, 646f5c4f2eSBalaji V. Iyer unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 656f5c4f2eSBalaji V. Iyer PatternBenefit benefit = 1) 666f5c4f2eSBalaji V. Iyer : OpConversionPattern(typeConverter, context, benefit), 676f5c4f2eSBalaji V. Iyer targetVectorBitWidth(targetVectBitWidth) {} 6835ef3994SIvan Butygin LogicalResult 6935ef3994SIvan Butygin matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, 7035ef3994SIvan Butygin ConversionPatternRewriter &rewriter) const override { 7135ef3994SIvan Butygin Location loc = constOp.getLoc(); 7235ef3994SIvan Butygin auto resType = 7335ef3994SIvan Butygin getTypeConverter()->convertType<VectorType>(constOp.getType()); 74d3aa92edSAndrzej Warzyński 75*bd5d361cSChao Chen if (!resType) 76*bd5d361cSChao Chen return rewriter.notifyMatchFailure(loc, "can't convert return type"); 77*bd5d361cSChao Chen 78d3aa92edSAndrzej Warzyński if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue())) 79d3aa92edSAndrzej Warzyński return rewriter.notifyMatchFailure( 80d3aa92edSAndrzej Warzyński loc, 81d3aa92edSAndrzej Warzyński "Cannot linearize a constant scalable vector that's not a splat"); 82d3aa92edSAndrzej Warzyński 836f5c4f2eSBalaji V. Iyer if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth)) 846f5c4f2eSBalaji V. Iyer return rewriter.notifyMatchFailure( 856f5c4f2eSBalaji V. Iyer loc, "Can't flatten since targetBitWidth <= OpSize"); 8635ef3994SIvan Butygin auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue()); 8735ef3994SIvan Butygin if (!dstElementsAttr) 8835ef3994SIvan Butygin return rewriter.notifyMatchFailure(loc, "unsupported attr type"); 8935ef3994SIvan Butygin 9035ef3994SIvan Butygin dstElementsAttr = dstElementsAttr.reshape(resType); 9135ef3994SIvan Butygin rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType, 9235ef3994SIvan Butygin dstElementsAttr); 9335ef3994SIvan Butygin return success(); 9435ef3994SIvan Butygin } 956f5c4f2eSBalaji V. Iyer 966f5c4f2eSBalaji V. Iyer private: 976f5c4f2eSBalaji V. Iyer unsigned targetVectorBitWidth; 9835ef3994SIvan Butygin }; 9935ef3994SIvan Butygin 10035ef3994SIvan Butygin struct LinearizeVectorizable final 10135ef3994SIvan Butygin : OpTraitConversionPattern<OpTrait::Vectorizable> { 10235ef3994SIvan Butygin using OpTraitConversionPattern::OpTraitConversionPattern; 10335ef3994SIvan Butygin 1046f5c4f2eSBalaji V. Iyer public: 1056f5c4f2eSBalaji V. Iyer LinearizeVectorizable( 1066f5c4f2eSBalaji V. Iyer const TypeConverter &typeConverter, MLIRContext *context, 1076f5c4f2eSBalaji V. Iyer unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 1086f5c4f2eSBalaji V. Iyer PatternBenefit benefit = 1) 1096f5c4f2eSBalaji V. Iyer : OpTraitConversionPattern(typeConverter, context, benefit), 1106f5c4f2eSBalaji V. Iyer targetVectorBitWidth(targetVectBitWidth) {} 11135ef3994SIvan Butygin LogicalResult 11235ef3994SIvan Butygin matchAndRewrite(Operation *op, ArrayRef<Value> operands, 11335ef3994SIvan Butygin ConversionPatternRewriter &rewriter) const override { 1146f5c4f2eSBalaji V. Iyer if (!isLessThanTargetBitWidth(op, targetVectorBitWidth)) 1156f5c4f2eSBalaji V. Iyer return rewriter.notifyMatchFailure( 1166f5c4f2eSBalaji V. Iyer op->getLoc(), "Can't flatten since targetBitWidth <= OpSize"); 11735ef3994SIvan Butygin FailureOr<Operation *> newOp = 11835ef3994SIvan Butygin convertOpResultTypes(op, operands, *getTypeConverter(), rewriter); 11935ef3994SIvan Butygin if (failed(newOp)) 12035ef3994SIvan Butygin return failure(); 12135ef3994SIvan Butygin 12235ef3994SIvan Butygin rewriter.replaceOp(op, (*newOp)->getResults()); 12335ef3994SIvan Butygin return success(); 12435ef3994SIvan Butygin } 1256f5c4f2eSBalaji V. Iyer 1266f5c4f2eSBalaji V. Iyer private: 1276f5c4f2eSBalaji V. Iyer unsigned targetVectorBitWidth; 12835ef3994SIvan Butygin }; 129c577f91dSCharitha Saumya 130c577f91dSCharitha Saumya /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works 131c577f91dSCharitha Saumya /// on a linearized vector. 132c577f91dSCharitha Saumya /// Following, 133c577f91dSCharitha Saumya /// vector.extract_strided_slice %source 134c577f91dSCharitha Saumya /// { offsets = [..], strides = [..], sizes = [..] } 135c577f91dSCharitha Saumya /// is converted to : 136c577f91dSCharitha Saumya /// %source_1d = vector.shape_cast %source 137c577f91dSCharitha Saumya /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] 138c577f91dSCharitha Saumya /// %out_nd = vector.shape_cast %out_1d 139c577f91dSCharitha Saumya /// `shuffle_indices_1d` is computed using the offsets and sizes of the 140c577f91dSCharitha Saumya /// extraction. 141c577f91dSCharitha Saumya struct LinearizeVectorExtractStridedSlice final 142c577f91dSCharitha Saumya : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { 143c577f91dSCharitha Saumya using OpConversionPattern::OpConversionPattern; 144c577f91dSCharitha Saumya LinearizeVectorExtractStridedSlice( 145c577f91dSCharitha Saumya const TypeConverter &typeConverter, MLIRContext *context, 146c577f91dSCharitha Saumya unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 147c577f91dSCharitha Saumya PatternBenefit benefit = 1) 148c577f91dSCharitha Saumya : OpConversionPattern(typeConverter, context, benefit), 149c577f91dSCharitha Saumya targetVectorBitWidth(targetVectBitWidth) {} 150c577f91dSCharitha Saumya 151c577f91dSCharitha Saumya LogicalResult 152c577f91dSCharitha Saumya matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, 153c577f91dSCharitha Saumya ConversionPatternRewriter &rewriter) const override { 15474a105adSArtem Kroviakov VectorType dstType = 15574a105adSArtem Kroviakov getTypeConverter()->convertType<VectorType>(extractOp.getType()); 15674a105adSArtem Kroviakov assert(dstType && "vector type destination expected."); 15774a105adSArtem Kroviakov if (extractOp.getVector().getType().isScalable() || dstType.isScalable()) 15874a105adSArtem Kroviakov return rewriter.notifyMatchFailure(extractOp, 159c577f91dSCharitha Saumya "scalable vectors are not supported."); 160c577f91dSCharitha Saumya if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) 161c577f91dSCharitha Saumya return rewriter.notifyMatchFailure( 162c577f91dSCharitha Saumya extractOp, "Can't flatten since targetBitWidth <= OpSize"); 163c577f91dSCharitha Saumya 164c577f91dSCharitha Saumya ArrayAttr offsets = extractOp.getOffsets(); 165c577f91dSCharitha Saumya ArrayAttr sizes = extractOp.getSizes(); 166c577f91dSCharitha Saumya ArrayAttr strides = extractOp.getStrides(); 167c577f91dSCharitha Saumya if (!isConstantIntValue(strides[0], 1)) 168c577f91dSCharitha Saumya return rewriter.notifyMatchFailure( 169c577f91dSCharitha Saumya extractOp, "Strided slice with stride != 1 is not supported."); 170c577f91dSCharitha Saumya Value srcVector = adaptor.getVector(); 171c577f91dSCharitha Saumya // If kD offsets are specified for nD source vector (n > k), the granularity 172c577f91dSCharitha Saumya // of the extraction is greater than 1. In this case last (n-k) dimensions 173c577f91dSCharitha Saumya // form the extraction granularity. 174c577f91dSCharitha Saumya // Example : 175c577f91dSCharitha Saumya // vector.extract_strided_slice %src { 176c577f91dSCharitha Saumya // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : 177c577f91dSCharitha Saumya // vector<4x8x8xf32> to vector<2x2x8xf32> 178c577f91dSCharitha Saumya // Here, extraction granularity is 8. 179c577f91dSCharitha Saumya int64_t extractGranularitySize = 1; 180c577f91dSCharitha Saumya int64_t nD = extractOp.getSourceVectorType().getRank(); 181c577f91dSCharitha Saumya int64_t kD = (int64_t)offsets.size(); 182c577f91dSCharitha Saumya int64_t k = kD; 183c577f91dSCharitha Saumya while (k < nD) { 184c577f91dSCharitha Saumya extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k]; 185c577f91dSCharitha Saumya ++k; 186c577f91dSCharitha Saumya } 187c577f91dSCharitha Saumya // Get total number of extracted slices. 188c577f91dSCharitha Saumya int64_t nExtractedSlices = 1; 189c577f91dSCharitha Saumya for (Attribute size : sizes) { 190fac349a1SChristian Sigg nExtractedSlices *= cast<IntegerAttr>(size).getInt(); 191c577f91dSCharitha Saumya } 192c577f91dSCharitha Saumya // Compute the strides of the source vector considering first k dimensions. 193c577f91dSCharitha Saumya llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize); 194c577f91dSCharitha Saumya for (int i = kD - 2; i >= 0; --i) { 195c577f91dSCharitha Saumya sourceStrides[i] = sourceStrides[i + 1] * 196c577f91dSCharitha Saumya extractOp.getSourceVectorType().getShape()[i + 1]; 197c577f91dSCharitha Saumya } 198c577f91dSCharitha Saumya // Final shuffle indices has nExtractedSlices * extractGranularitySize 199c577f91dSCharitha Saumya // elements. 200c577f91dSCharitha Saumya llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * 201c577f91dSCharitha Saumya extractGranularitySize); 202c577f91dSCharitha Saumya // Compute the strides of the extracted kD vector. 203c577f91dSCharitha Saumya llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1); 204c577f91dSCharitha Saumya // Compute extractedStrides. 205c577f91dSCharitha Saumya for (int i = kD - 2; i >= 0; --i) { 206c577f91dSCharitha Saumya extractedStrides[i] = 207fac349a1SChristian Sigg extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt(); 208c577f91dSCharitha Saumya } 209c577f91dSCharitha Saumya // Iterate over all extracted slices from 0 to nExtractedSlices - 1 210c577f91dSCharitha Saumya // and compute the multi-dimensional index and the corresponding linearized 211c577f91dSCharitha Saumya // index within the source vector. 212c577f91dSCharitha Saumya for (int64_t i = 0; i < nExtractedSlices; ++i) { 213c577f91dSCharitha Saumya int64_t index = i; 214c577f91dSCharitha Saumya // Compute the corresponding multi-dimensional index. 215c577f91dSCharitha Saumya llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0); 216c577f91dSCharitha Saumya for (int64_t j = 0; j < kD; ++j) { 217c577f91dSCharitha Saumya multiDimIndex[j] = (index / extractedStrides[j]); 218c577f91dSCharitha Saumya index -= multiDimIndex[j] * extractedStrides[j]; 219c577f91dSCharitha Saumya } 220c577f91dSCharitha Saumya // Compute the corresponding linearized index in the source vector 221c577f91dSCharitha Saumya // i.e. shift the multiDimIndex by the offsets. 222c577f91dSCharitha Saumya int64_t linearizedIndex = 0; 223c577f91dSCharitha Saumya for (int64_t j = 0; j < kD; ++j) { 224c577f91dSCharitha Saumya linearizedIndex += 225fac349a1SChristian Sigg (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) * 226c577f91dSCharitha Saumya sourceStrides[j]; 227c577f91dSCharitha Saumya } 228c577f91dSCharitha Saumya // Fill the indices array form linearizedIndex to linearizedIndex + 229c577f91dSCharitha Saumya // extractGranularitySize. 230c577f91dSCharitha Saumya for (int64_t j = 0; j < extractGranularitySize; ++j) { 231c577f91dSCharitha Saumya indices[i * extractGranularitySize + j] = linearizedIndex + j; 232c577f91dSCharitha Saumya } 233c577f91dSCharitha Saumya } 234c577f91dSCharitha Saumya // Perform a shuffle to extract the kD vector. 235c577f91dSCharitha Saumya rewriter.replaceOpWithNewOp<vector::ShuffleOp>( 236b4444dcaSBenjamin Maxwell extractOp, dstType, srcVector, srcVector, indices); 237c577f91dSCharitha Saumya return success(); 238c577f91dSCharitha Saumya } 239c577f91dSCharitha Saumya 240c577f91dSCharitha Saumya private: 241c577f91dSCharitha Saumya unsigned targetVectorBitWidth; 242c577f91dSCharitha Saumya }; 243c577f91dSCharitha Saumya 244c577f91dSCharitha Saumya /// This pattern converts the ShuffleOp that works on nD (n > 1) 245c577f91dSCharitha Saumya /// vectors to a ShuffleOp that works on linearized vectors. 246c577f91dSCharitha Saumya /// Following, 247c577f91dSCharitha Saumya /// vector.shuffle %v1, %v2 [ shuffle_indices ] 248c577f91dSCharitha Saumya /// is converted to : 249c577f91dSCharitha Saumya /// %v1_1d = vector.shape_cast %v1 250c577f91dSCharitha Saumya /// %v2_1d = vector.shape_cast %v2 251c577f91dSCharitha Saumya /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] 252c577f91dSCharitha Saumya /// %out_nd = vector.shape_cast %out_1d 253c577f91dSCharitha Saumya // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices` 254c577f91dSCharitha Saumya /// of the original shuffle operation. 255c577f91dSCharitha Saumya struct LinearizeVectorShuffle final 256c577f91dSCharitha Saumya : public OpConversionPattern<vector::ShuffleOp> { 257c577f91dSCharitha Saumya using OpConversionPattern::OpConversionPattern; 258c577f91dSCharitha Saumya LinearizeVectorShuffle( 259c577f91dSCharitha Saumya const TypeConverter &typeConverter, MLIRContext *context, 260c577f91dSCharitha Saumya unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 261c577f91dSCharitha Saumya PatternBenefit benefit = 1) 262c577f91dSCharitha Saumya : OpConversionPattern(typeConverter, context, benefit), 263c577f91dSCharitha Saumya targetVectorBitWidth(targetVectBitWidth) {} 264c577f91dSCharitha Saumya 265c577f91dSCharitha Saumya LogicalResult 266c577f91dSCharitha Saumya matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, 267c577f91dSCharitha Saumya ConversionPatternRewriter &rewriter) const override { 26874a105adSArtem Kroviakov VectorType dstType = 26974a105adSArtem Kroviakov getTypeConverter()->convertType<VectorType>(shuffleOp.getType()); 27074a105adSArtem Kroviakov assert(dstType && "vector type destination expected."); 27174a105adSArtem Kroviakov // The assert is used because vector.shuffle does not support scalable 27274a105adSArtem Kroviakov // vectors. 273c577f91dSCharitha Saumya assert(!(shuffleOp.getV1VectorType().isScalable() || 274c577f91dSCharitha Saumya shuffleOp.getV2VectorType().isScalable() || 27574a105adSArtem Kroviakov dstType.isScalable()) && 276c577f91dSCharitha Saumya "scalable vectors are not supported."); 277c577f91dSCharitha Saumya if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) 278c577f91dSCharitha Saumya return rewriter.notifyMatchFailure( 279c577f91dSCharitha Saumya shuffleOp, "Can't flatten since targetBitWidth <= OpSize"); 280c577f91dSCharitha Saumya 281c577f91dSCharitha Saumya Value vec1 = adaptor.getV1(); 282c577f91dSCharitha Saumya Value vec2 = adaptor.getV2(); 283c577f91dSCharitha Saumya int shuffleSliceLen = 1; 284c577f91dSCharitha Saumya int rank = shuffleOp.getV1().getType().getRank(); 285c577f91dSCharitha Saumya 286c577f91dSCharitha Saumya // If rank > 1, we need to do the shuffle in the granularity of slices 287c577f91dSCharitha Saumya // instead of scalars. Size of the slice is equal to the rank-1 innermost 288c577f91dSCharitha Saumya // dims. Mask of the shuffle op specifies which slice to take from the 289c577f91dSCharitha Saumya // outermost dim. 290c577f91dSCharitha Saumya if (rank > 1) { 291c577f91dSCharitha Saumya llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape(); 292c577f91dSCharitha Saumya for (unsigned i = 1; i < shape.size(); ++i) { 293c577f91dSCharitha Saumya shuffleSliceLen *= shape[i]; 294c577f91dSCharitha Saumya } 295c577f91dSCharitha Saumya } 296c577f91dSCharitha Saumya 297c577f91dSCharitha Saumya // For each value in the mask, we generate the indices of the source vectors 298c577f91dSCharitha Saumya // that needs to be shuffled to the destination vector. If shuffleSliceLen > 299c577f91dSCharitha Saumya // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of 300c577f91dSCharitha Saumya // elements) instead of scalars. 301b4444dcaSBenjamin Maxwell ArrayRef<int64_t> mask = shuffleOp.getMask(); 302c577f91dSCharitha Saumya int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; 303c577f91dSCharitha Saumya llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts); 304b4444dcaSBenjamin Maxwell for (auto [i, value] : llvm::enumerate(mask)) { 305c577f91dSCharitha Saumya std::iota(indices.begin() + shuffleSliceLen * i, 306c577f91dSCharitha Saumya indices.begin() + shuffleSliceLen * (i + 1), 307b4444dcaSBenjamin Maxwell shuffleSliceLen * value); 308c577f91dSCharitha Saumya } 309c577f91dSCharitha Saumya 310b4444dcaSBenjamin Maxwell rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1, 311b4444dcaSBenjamin Maxwell vec2, indices); 312c577f91dSCharitha Saumya return success(); 313c577f91dSCharitha Saumya } 314c577f91dSCharitha Saumya 315c577f91dSCharitha Saumya private: 316c577f91dSCharitha Saumya unsigned targetVectorBitWidth; 317c577f91dSCharitha Saumya }; 318c577f91dSCharitha Saumya 319c577f91dSCharitha Saumya /// This pattern converts the ExtractOp to a ShuffleOp that works on a 320c577f91dSCharitha Saumya /// linearized vector. 321c577f91dSCharitha Saumya /// Following, 322c577f91dSCharitha Saumya /// vector.extract %source [ position ] 323c577f91dSCharitha Saumya /// is converted to : 324c577f91dSCharitha Saumya /// %source_1d = vector.shape_cast %source 325c577f91dSCharitha Saumya /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] 326c577f91dSCharitha Saumya /// %out_nd = vector.shape_cast %out_1d 327c577f91dSCharitha Saumya /// `shuffle_indices_1d` is computed using the position of the original extract. 328c577f91dSCharitha Saumya struct LinearizeVectorExtract final 329c577f91dSCharitha Saumya : public OpConversionPattern<vector::ExtractOp> { 330c577f91dSCharitha Saumya using OpConversionPattern::OpConversionPattern; 331c577f91dSCharitha Saumya LinearizeVectorExtract( 332c577f91dSCharitha Saumya const TypeConverter &typeConverter, MLIRContext *context, 333c577f91dSCharitha Saumya unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 334c577f91dSCharitha Saumya PatternBenefit benefit = 1) 335c577f91dSCharitha Saumya : OpConversionPattern(typeConverter, context, benefit), 336c577f91dSCharitha Saumya targetVectorBitWidth(targetVectBitWidth) {} 337c577f91dSCharitha Saumya LogicalResult 338c577f91dSCharitha Saumya matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 339c577f91dSCharitha Saumya ConversionPatternRewriter &rewriter) const override { 340c577f91dSCharitha Saumya Type dstTy = getTypeConverter()->convertType(extractOp.getType()); 34150febdebSLongsheng Mou if (!dstTy) 34250febdebSLongsheng Mou return rewriter.notifyMatchFailure(extractOp, 34350febdebSLongsheng Mou "expected n-D vector type."); 34450febdebSLongsheng Mou 34574a105adSArtem Kroviakov if (extractOp.getVector().getType().isScalable() || 34674a105adSArtem Kroviakov cast<VectorType>(dstTy).isScalable()) 34774a105adSArtem Kroviakov return rewriter.notifyMatchFailure(extractOp, 348c577f91dSCharitha Saumya "scalable vectors are not supported."); 349c577f91dSCharitha Saumya if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) 350c577f91dSCharitha Saumya return rewriter.notifyMatchFailure( 351c577f91dSCharitha Saumya extractOp, "Can't flatten since targetBitWidth <= OpSize"); 352c577f91dSCharitha Saumya 353c577f91dSCharitha Saumya // Dynamic position is not supported. 354c577f91dSCharitha Saumya if (extractOp.hasDynamicPosition()) 355c577f91dSCharitha Saumya return rewriter.notifyMatchFailure(extractOp, 356c577f91dSCharitha Saumya "dynamic position is not supported."); 357c577f91dSCharitha Saumya 358c577f91dSCharitha Saumya llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape(); 359c577f91dSCharitha Saumya int64_t size = extractOp.getVector().getType().getNumElements(); 360c577f91dSCharitha Saumya 361c577f91dSCharitha Saumya // Compute linearized offset. 362c577f91dSCharitha Saumya int64_t linearizedOffset = 0; 363c577f91dSCharitha Saumya llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition(); 364c577f91dSCharitha Saumya for (auto [i, off] : llvm::enumerate(offsets)) { 365c577f91dSCharitha Saumya size /= shape[i]; 366c577f91dSCharitha Saumya linearizedOffset += offsets[i] * size; 367c577f91dSCharitha Saumya } 368c577f91dSCharitha Saumya 369c577f91dSCharitha Saumya llvm::SmallVector<int64_t, 2> indices(size); 370c577f91dSCharitha Saumya std::iota(indices.begin(), indices.end(), linearizedOffset); 371c577f91dSCharitha Saumya rewriter.replaceOpWithNewOp<vector::ShuffleOp>( 372b4444dcaSBenjamin Maxwell extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices); 373c577f91dSCharitha Saumya 374c577f91dSCharitha Saumya return success(); 375c577f91dSCharitha Saumya } 376c577f91dSCharitha Saumya 377c577f91dSCharitha Saumya private: 378c577f91dSCharitha Saumya unsigned targetVectorBitWidth; 379c577f91dSCharitha Saumya }; 38001fbc565SArtem Kroviakov 38101fbc565SArtem Kroviakov /// This pattern converts the InsertOp to a ShuffleOp that works on a 38201fbc565SArtem Kroviakov /// linearized vector. 38301fbc565SArtem Kroviakov /// Following, 38401fbc565SArtem Kroviakov /// vector.insert %source %destination [ position ] 38501fbc565SArtem Kroviakov /// is converted to : 38601fbc565SArtem Kroviakov /// %source_1d = vector.shape_cast %source 38701fbc565SArtem Kroviakov /// %destination_1d = vector.shape_cast %destination 38801fbc565SArtem Kroviakov /// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d 38901fbc565SArtem Kroviakov /// ] %out_nd = vector.shape_cast %out_1d 39001fbc565SArtem Kroviakov /// `shuffle_indices_1d` is computed using the position of the original insert. 39101fbc565SArtem Kroviakov struct LinearizeVectorInsert final 39201fbc565SArtem Kroviakov : public OpConversionPattern<vector::InsertOp> { 39301fbc565SArtem Kroviakov using OpConversionPattern::OpConversionPattern; 39401fbc565SArtem Kroviakov LinearizeVectorInsert( 39501fbc565SArtem Kroviakov const TypeConverter &typeConverter, MLIRContext *context, 39601fbc565SArtem Kroviakov unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 39701fbc565SArtem Kroviakov PatternBenefit benefit = 1) 39801fbc565SArtem Kroviakov : OpConversionPattern(typeConverter, context, benefit), 39901fbc565SArtem Kroviakov targetVectorBitWidth(targetVectBitWidth) {} 40001fbc565SArtem Kroviakov LogicalResult 40101fbc565SArtem Kroviakov matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 40201fbc565SArtem Kroviakov ConversionPatternRewriter &rewriter) const override { 40374a105adSArtem Kroviakov VectorType dstTy = getTypeConverter()->convertType<VectorType>( 40474a105adSArtem Kroviakov insertOp.getDestVectorType()); 40574a105adSArtem Kroviakov assert(dstTy && "vector type destination expected."); 40674a105adSArtem Kroviakov if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable()) 40774a105adSArtem Kroviakov return rewriter.notifyMatchFailure(insertOp, 40801fbc565SArtem Kroviakov "scalable vectors are not supported."); 40901fbc565SArtem Kroviakov 41001fbc565SArtem Kroviakov if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(), 41101fbc565SArtem Kroviakov targetVectorBitWidth)) 41201fbc565SArtem Kroviakov return rewriter.notifyMatchFailure( 41301fbc565SArtem Kroviakov insertOp, "Can't flatten since targetBitWidth < OpSize"); 41401fbc565SArtem Kroviakov 41501fbc565SArtem Kroviakov // dynamic position is not supported 41601fbc565SArtem Kroviakov if (insertOp.hasDynamicPosition()) 41701fbc565SArtem Kroviakov return rewriter.notifyMatchFailure(insertOp, 41801fbc565SArtem Kroviakov "dynamic position is not supported."); 41901fbc565SArtem Kroviakov auto srcTy = insertOp.getSourceType(); 42001fbc565SArtem Kroviakov auto srcAsVec = dyn_cast<VectorType>(srcTy); 42101fbc565SArtem Kroviakov uint64_t srcSize = 0; 42201fbc565SArtem Kroviakov if (srcAsVec) { 42301fbc565SArtem Kroviakov srcSize = srcAsVec.getNumElements(); 42401fbc565SArtem Kroviakov } else { 42501fbc565SArtem Kroviakov return rewriter.notifyMatchFailure(insertOp, 42601fbc565SArtem Kroviakov "scalars are not supported."); 42701fbc565SArtem Kroviakov } 42801fbc565SArtem Kroviakov 42901fbc565SArtem Kroviakov auto dstShape = insertOp.getDestVectorType().getShape(); 43001fbc565SArtem Kroviakov const auto dstSize = insertOp.getDestVectorType().getNumElements(); 43101fbc565SArtem Kroviakov auto dstSizeForOffsets = dstSize; 43201fbc565SArtem Kroviakov 43301fbc565SArtem Kroviakov // compute linearized offset 43401fbc565SArtem Kroviakov int64_t linearizedOffset = 0; 43501fbc565SArtem Kroviakov auto offsetsNd = insertOp.getStaticPosition(); 43601fbc565SArtem Kroviakov for (auto [dim, offset] : llvm::enumerate(offsetsNd)) { 43701fbc565SArtem Kroviakov dstSizeForOffsets /= dstShape[dim]; 43801fbc565SArtem Kroviakov linearizedOffset += offset * dstSizeForOffsets; 43901fbc565SArtem Kroviakov } 44001fbc565SArtem Kroviakov 44101fbc565SArtem Kroviakov llvm::SmallVector<int64_t, 2> indices(dstSize); 44201fbc565SArtem Kroviakov auto origValsUntil = indices.begin(); 44301fbc565SArtem Kroviakov std::advance(origValsUntil, linearizedOffset); 44401fbc565SArtem Kroviakov std::iota(indices.begin(), origValsUntil, 44501fbc565SArtem Kroviakov 0); // original values that remain [0, offset) 44601fbc565SArtem Kroviakov auto newValsUntil = origValsUntil; 44701fbc565SArtem Kroviakov std::advance(newValsUntil, srcSize); 44801fbc565SArtem Kroviakov std::iota(origValsUntil, newValsUntil, 44901fbc565SArtem Kroviakov dstSize); // new values [offset, offset+srcNumElements) 45001fbc565SArtem Kroviakov std::iota(newValsUntil, indices.end(), 45101fbc565SArtem Kroviakov linearizedOffset + srcSize); // the rest of original values 45201fbc565SArtem Kroviakov // [offset+srcNumElements, end) 45301fbc565SArtem Kroviakov 45401fbc565SArtem Kroviakov rewriter.replaceOpWithNewOp<vector::ShuffleOp>( 455b4444dcaSBenjamin Maxwell insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices); 45601fbc565SArtem Kroviakov 45701fbc565SArtem Kroviakov return success(); 45801fbc565SArtem Kroviakov } 45901fbc565SArtem Kroviakov 46001fbc565SArtem Kroviakov private: 46101fbc565SArtem Kroviakov unsigned targetVectorBitWidth; 46201fbc565SArtem Kroviakov }; 463*bd5d361cSChao Chen 464*bd5d361cSChao Chen /// This pattern converts the BitCastOp that works on nD (n > 1) 465*bd5d361cSChao Chen /// vectors to a BitCastOp that works on linearized vectors. 466*bd5d361cSChao Chen /// Following, 467*bd5d361cSChao Chen /// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16> 468*bd5d361cSChao Chen /// is converted to : 469*bd5d361cSChao Chen /// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32> 470*bd5d361cSChao Chen /// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16> 471*bd5d361cSChao Chen /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> 472*bd5d361cSChao Chen struct LinearizeVectorBitCast final 473*bd5d361cSChao Chen : public OpConversionPattern<vector::BitCastOp> { 474*bd5d361cSChao Chen using OpConversionPattern::OpConversionPattern; 475*bd5d361cSChao Chen LinearizeVectorBitCast( 476*bd5d361cSChao Chen const TypeConverter &typeConverter, MLIRContext *context, 477*bd5d361cSChao Chen unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), 478*bd5d361cSChao Chen PatternBenefit benefit = 1) 479*bd5d361cSChao Chen : OpConversionPattern(typeConverter, context, benefit), 480*bd5d361cSChao Chen targetVectorBitWidth(targetVectBitWidth) {} 481*bd5d361cSChao Chen LogicalResult 482*bd5d361cSChao Chen matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, 483*bd5d361cSChao Chen ConversionPatternRewriter &rewriter) const override { 484*bd5d361cSChao Chen Location loc = castOp.getLoc(); 485*bd5d361cSChao Chen auto resType = getTypeConverter()->convertType(castOp.getType()); 486*bd5d361cSChao Chen if (!resType) 487*bd5d361cSChao Chen return rewriter.notifyMatchFailure(loc, "can't convert return type."); 488*bd5d361cSChao Chen 489*bd5d361cSChao Chen if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth)) 490*bd5d361cSChao Chen return rewriter.notifyMatchFailure( 491*bd5d361cSChao Chen loc, "Can't flatten since targetBitWidth <= OpSize"); 492*bd5d361cSChao Chen 493*bd5d361cSChao Chen rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, 494*bd5d361cSChao Chen adaptor.getSource()); 495*bd5d361cSChao Chen return mlir::success(); 496*bd5d361cSChao Chen } 497*bd5d361cSChao Chen 498*bd5d361cSChao Chen private: 499*bd5d361cSChao Chen unsigned targetVectorBitWidth; 500*bd5d361cSChao Chen }; 501*bd5d361cSChao Chen 50235ef3994SIvan Butygin } // namespace 50335ef3994SIvan Butygin 50435ef3994SIvan Butygin void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( 50535ef3994SIvan Butygin TypeConverter &typeConverter, RewritePatternSet &patterns, 5066f5c4f2eSBalaji V. Iyer ConversionTarget &target, unsigned targetBitWidth) { 5076f5c4f2eSBalaji V. Iyer 50835ef3994SIvan Butygin typeConverter.addConversion([](VectorType type) -> std::optional<Type> { 509d3aa92edSAndrzej Warzyński if (!isLinearizableVector(type)) 51035ef3994SIvan Butygin return type; 51135ef3994SIvan Butygin 512d3aa92edSAndrzej Warzyński return VectorType::get(type.getNumElements(), type.getElementType(), 513d3aa92edSAndrzej Warzyński type.isScalable()); 51435ef3994SIvan Butygin }); 51535ef3994SIvan Butygin 51635ef3994SIvan Butygin auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, 51735ef3994SIvan Butygin Location loc) -> Value { 51835ef3994SIvan Butygin if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) || 51935ef3994SIvan Butygin !isa<VectorType>(type)) 52035ef3994SIvan Butygin return nullptr; 52135ef3994SIvan Butygin 52235ef3994SIvan Butygin return builder.create<vector::ShapeCastOp>(loc, type, inputs.front()); 52335ef3994SIvan Butygin }; 52435ef3994SIvan Butygin typeConverter.addSourceMaterialization(materializeCast); 52535ef3994SIvan Butygin typeConverter.addTargetMaterialization(materializeCast); 52635ef3994SIvan Butygin target.markUnknownOpDynamicallyLegal( 5276f5c4f2eSBalaji V. Iyer [=](Operation *op) -> std::optional<bool> { 528*bd5d361cSChao Chen if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) || 5296f5c4f2eSBalaji V. Iyer op->hasTrait<OpTrait::Vectorizable>())) { 5306f5c4f2eSBalaji V. Iyer return (isLessThanTargetBitWidth(op, targetBitWidth) 5316f5c4f2eSBalaji V. Iyer ? typeConverter.isLegal(op) 5326f5c4f2eSBalaji V. Iyer : true); 5336f5c4f2eSBalaji V. Iyer } 53435ef3994SIvan Butygin return std::nullopt; 53535ef3994SIvan Butygin }); 53635ef3994SIvan Butygin 537*bd5d361cSChao Chen patterns 538*bd5d361cSChao Chen .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>( 5396f5c4f2eSBalaji V. Iyer typeConverter, patterns.getContext(), targetBitWidth); 54035ef3994SIvan Butygin } 541c577f91dSCharitha Saumya 542c577f91dSCharitha Saumya void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( 543206fad0eSMatthias Springer const TypeConverter &typeConverter, RewritePatternSet &patterns, 544c577f91dSCharitha Saumya ConversionTarget &target, unsigned int targetBitWidth) { 545c577f91dSCharitha Saumya target.addDynamicallyLegalOp<vector::ShuffleOp>( 546c577f91dSCharitha Saumya [=](vector::ShuffleOp shuffleOp) -> bool { 547c577f91dSCharitha Saumya return isLessThanTargetBitWidth(shuffleOp, targetBitWidth) 548c577f91dSCharitha Saumya ? (typeConverter.isLegal(shuffleOp) && 549fac349a1SChristian Sigg cast<mlir::VectorType>(shuffleOp.getResult().getType()) 550c577f91dSCharitha Saumya .getRank() == 1) 551c577f91dSCharitha Saumya : true; 552c577f91dSCharitha Saumya }); 553c577f91dSCharitha Saumya patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract, 55401fbc565SArtem Kroviakov LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>( 555c577f91dSCharitha Saumya typeConverter, patterns.getContext(), targetBitWidth); 556c577f91dSCharitha Saumya } 557