//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>

using namespace mlir;

static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  auto resultTypes = op->getResultTypes();
  for (auto resType : resultTypes) {
    VectorType vecType = dyn_cast<VectorType>(resType);
    // Reject index since getElementTypeBitWidth will abort for Index types.
    if (!vecType || vecType.getElementType().isIndex())
      return false;
    // There are no dimensions to fold if it is a 0-D vector.
    if (vecType.getRank() == 0)
      return false;
    unsigned trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}
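// Illustrative example (not part of the original file): with
// targetBitWidth = 128, an op producing vector<4x8xf16> has a trailing
// dimension of 8 * 16 = 128 bits; since 128 >= 128 the helper returns false
// and the op is left untouched, while vector<4x4xf16> (4 * 16 = 64 bits)
// qualifies for linearization.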
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
  VectorType vecType = dyn_cast<VectorType>(t);
  // Reject index since getElementTypeBitWidth will abort for Index types.
  if (!vecType || vecType.getElementType().isIndex())
    return false;
  // There are no dimensions to fold if it is a 0-D vector.
  if (vecType.getRank() == 0)
    return false;
  unsigned trailingVecDimBitWidth =
      vecType.getShape().back() * vecType.getElementTypeBitWidth();
  return trailingVecDimBitWidth <= targetBitWidth;
}

namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeConstant(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = constOp.getLoc();
    auto resType =
        getTypeConverter()->convertType<VectorType>(constOp.getType());
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type");

    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");
    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
    if (!dstElementsAttr)
      return rewriter.notifyMatchFailure(loc, "unsupported attr type");

    dstElementsAttr = dstElementsAttr.reshape(resType);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
                                                   dstElementsAttr);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
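// Illustrative example (not part of the original file): given a type
// converter that maps vector<2x2xf32> to vector<4xf32>, LinearizeConstant
// rewrites
//   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// into
//   %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
// with vector.shape_cast ops materialized at the type boundaries by the
// conversion driver.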
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// Following,
///   vector.extract_strided_slice %source
///       { offsets = [..], strides = [..], sizes = [..] }
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    assert(dstType && "vector type destination expected.");
    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
      return rewriter.notifyMatchFailure(
          extractOp, "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();

    // If kD offsets are specified for an nD source vector (n > k), the
    // granularity of the extraction is greater than 1. In this case the last
    // (n-k) dimensions form the extraction granularity.
    // Example :
    //  vector.extract_strided_slice %src {
    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //      vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, the extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }

    // Get the total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }

    // Compute the strides of the source vector considering the first k
    // dimensions.
    llvm::SmallVector<int64_t> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // The final shuffle indices array has nExtractedSlices *
    // extractGranularitySize elements.
    llvm::SmallVector<int64_t> indices(nExtractedSlices *
                                       extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t> extractedStrides(kD, 1);
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
      llvm::SmallVector<int64_t> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector,
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector, indices);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
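// Worked example (illustrative, not part of the original file): for
//   vector.extract_strided_slice %src { offsets = [1], sizes = [2],
//       strides = [1] } : vector<4x2xf32> to vector<2x2xf32>
// the extraction granularity is 2 (the trailing dimension), so
// sourceStrides = [2], extractedStrides = [1], and the two extracted slices
// linearize to offsets (1+0)*2 = 2 and (1+1)*2 = 4, yielding shuffle
// indices [2, 3, 4, 5].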
/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to :
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");
    // The assert is used because vector.shuffle does not support scalable
    // vectors.
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             dstType.isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. The size of the slice is equal to the product of
    // the rank-1 innermost dims. The mask of the shuffle op specifies which
    // slice to take from the outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source
    // vectors that need to be shuffled to the destination vector. If
    // shuffleSliceLen > 1 we need to shuffle the slices (consecutive
    // shuffleSliceLen number of elements) instead of scalars.
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] : llvm::enumerate(mask)) {
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * value);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
                                                   vec2, indices);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
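// Worked example (illustrative, not part of the original file): shuffling
// two vector<2x2xf32> values with mask [1, 2] selects one row-slice from
// each source. With shuffleSliceLen = 2 the linearized mask expands to
// [2, 3, 4, 5]: slice 1 of %v1 (elements 2-3) followed by slice 0 of %v2
// (elements 4-5 of the concatenated 1-D operands).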
/// This pattern converts the ExtractOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.extract %source [ position ]
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original
/// extract.
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtract(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    if (!dstTy)
      return rewriter.notifyMatchFailure(extractOp,
                                         "expected n-D vector type.");

    if (extractOp.getVector().getType().isScalable() ||
        cast<VectorType>(dstTy).isScalable())
      return rewriter.notifyMatchFailure(
          extractOp, "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
    int64_t size = extractOp.getVector().getType().getNumElements();

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += off * size;
    }

    llvm::SmallVector<int64_t> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
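// Worked example (illustrative, not part of the original file):
//   vector.extract %source[1] : vector<2x4xf32>
// starts with size = 8 and shape = [2, 4]; after the loop size = 4 and
// linearizedOffset = 1 * 4 = 4, so the replacement shuffle uses indices
// [4, 5, 6, 7] to pull the second row out of the flattened source.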
/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.insert %source %destination [ position ]
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %destination_1d = vector.shape_cast %destination
///   %out_1d = vector.shuffle %destination_1d, %source_1d
///       [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");
    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(
          insertOp, "scalable vectors are not supported.");

    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");

    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = 0;
    if (srcAsVec) {
      srcSize = srcAsVec.getNumElements();
    } else {
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    }

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    llvm::SmallVector<int64_t> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // original values that remain [0, offset)
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // new values [offset, offset+srcNumElements)
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // the rest of the original values
                                           // [offset+srcNumElements, end)

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
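// Worked example (illustrative, not part of the original file): inserting
// vector<2xf32> into vector<2x2xf32> at position [1] gives
// linearizedOffset = 1 * 2 = 2, srcSize = 2, dstSize = 4, so indices =
// [0, 1, 4, 5]: elements 0-1 come from the flattened destination and
// elements 4-5 select the inserted source, which the shuffle sees appended
// after the destination.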
/// This pattern converts the BitCastOp that works on nD (n > 1)
/// vectors to a BitCastOp that works on linearized vectors.
/// Following,
///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
/// is converted to :
///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
    : public OpConversionPattern<vector::BitCastOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorBitCast(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = castOp.getLoc();
    auto resType = getTypeConverter()->convertType(castOp.getType());
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type.");

    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                   adaptor.getSource());
    return mlir::success();
  }

private:
  unsigned targetVectorBitWidth;
};
} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns
      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
          typeConverter, patterns.getContext(), targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}
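// Usage sketch (illustrative, not part of the original file): a pass would
// typically drive these patterns through the dialect conversion framework,
// e.g.
//
//   MLIRContext *context = &getContext();
//   TypeConverter typeConverter;
//   RewritePatternSet patterns(context);
//   ConversionTarget target(*context);
//   vector::populateVectorLinearizeTypeConversionsAndLegality(
//       typeConverter, patterns, target, /*targetBitWidth=*/512);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();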