199ef9eebSMatthias Springer //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// 299ef9eebSMatthias Springer // 399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 699ef9eebSMatthias Springer // 799ef9eebSMatthias Springer //===----------------------------------------------------------------------===// 899ef9eebSMatthias Springer 9796d48b0SQuinn Dawkins #include <numeric> 10796d48b0SQuinn Dawkins 1198f6289aSDiego Caballero #include "mlir/Dialect/Arith/IR/Arith.h" 12ad9b5a4bSNirvedh #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 13e54236dfSLei Zhang #include "mlir/Dialect/Vector/IR/VectorOps.h" 1499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 1501128d4bSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 1699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 1799ef9eebSMatthias Springer #include "mlir/IR/Builders.h" 1899ef9eebSMatthias Springer #include "mlir/IR/TypeUtilities.h" 1999ef9eebSMatthias Springer 2099ef9eebSMatthias Springer #define DEBUG_TYPE "vector-drop-unit-dim" 2199ef9eebSMatthias Springer 2299ef9eebSMatthias Springer using namespace mlir; 2399ef9eebSMatthias Springer using namespace mlir::vector; 2499ef9eebSMatthias Springer 2599ef9eebSMatthias Springer // Trims leading one dimensions from `oldType` and returns the result type. 2699ef9eebSMatthias Springer // Returns `vector<1xT>` if `oldType` only has one element. 2799ef9eebSMatthias Springer static VectorType trimLeadingOneDims(VectorType oldType) { 2899ef9eebSMatthias Springer ArrayRef<int64_t> oldShape = oldType.getShape(); 29576b184dSAndrzej Warzynski ArrayRef<int64_t> newShape = oldShape; 30576b184dSAndrzej Warzynski 31576b184dSAndrzej Warzynski ArrayRef<bool> oldScalableDims = oldType.getScalableDims(); 32576b184dSAndrzej Warzynski ArrayRef<bool> newScalableDims = oldScalableDims; 33576b184dSAndrzej Warzynski 34576b184dSAndrzej Warzynski while (!newShape.empty() && newShape.front() == 1 && 35576b184dSAndrzej Warzynski !newScalableDims.front()) { 36576b184dSAndrzej Warzynski newShape = newShape.drop_front(1); 37576b184dSAndrzej Warzynski newScalableDims = newScalableDims.drop_front(1); 38576b184dSAndrzej Warzynski } 39576b184dSAndrzej Warzynski 4099ef9eebSMatthias Springer // Make sure we have at least 1 dimension per vector type requirements. 41576b184dSAndrzej Warzynski if (newShape.empty()) { 4299ef9eebSMatthias Springer newShape = oldShape.take_back(); 43576b184dSAndrzej Warzynski newScalableDims = oldType.getScalableDims().take_back(); 44576b184dSAndrzej Warzynski } 45576b184dSAndrzej Warzynski return VectorType::get(newShape, oldType.getElementType(), newScalableDims); 4699ef9eebSMatthias Springer } 4799ef9eebSMatthias Springer 4899ef9eebSMatthias Springer /// Return a smallVector of size `rank` containing all zeros. 4999ef9eebSMatthias Springer static SmallVector<int64_t> splatZero(int64_t rank) { 5099ef9eebSMatthias Springer return SmallVector<int64_t>(rank, 0); 5199ef9eebSMatthias Springer } 5299ef9eebSMatthias Springer namespace { 5399ef9eebSMatthias Springer 5499ef9eebSMatthias Springer // Casts away leading one dimensions in vector.extract_strided_slice's vector 55e54236dfSLei Zhang // input by inserting vector.broadcast. 5699ef9eebSMatthias Springer struct CastAwayExtractStridedSliceLeadingOneDim 5799ef9eebSMatthias Springer : public OpRewritePattern<vector::ExtractStridedSliceOp> { 5899ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern; 5999ef9eebSMatthias Springer 6099ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 6199ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 6299ef9eebSMatthias Springer // vector.extract_strided_slice requires the input and output vector to have 6399ef9eebSMatthias Springer // the same rank. Here we drop leading one dimensions from the input vector 6499ef9eebSMatthias Springer // type to make sure we don't cause mismatch. 65a1aad28dSLei Zhang VectorType oldSrcType = extractOp.getSourceVectorType(); 6699ef9eebSMatthias Springer VectorType newSrcType = trimLeadingOneDims(oldSrcType); 6799ef9eebSMatthias Springer 6899ef9eebSMatthias Springer if (newSrcType.getRank() == oldSrcType.getRank()) 6999ef9eebSMatthias Springer return failure(); 7099ef9eebSMatthias Springer 7199ef9eebSMatthias Springer int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 7299ef9eebSMatthias Springer 7399ef9eebSMatthias Springer VectorType oldDstType = extractOp.getType(); 7499ef9eebSMatthias Springer VectorType newDstType = 7599ef9eebSMatthias Springer VectorType::get(oldDstType.getShape().drop_front(dropCount), 760d72f0beSAndrzej Warzyński oldDstType.getElementType(), 770d72f0beSAndrzej Warzyński oldDstType.getScalableDims().drop_front(dropCount)); 7899ef9eebSMatthias Springer 7999ef9eebSMatthias Springer Location loc = extractOp.getLoc(); 8099ef9eebSMatthias Springer 8199ef9eebSMatthias Springer Value newSrcVector = rewriter.create<vector::ExtractOp>( 827c38fd60SJacques Pienaar loc, extractOp.getVector(), splatZero(dropCount)); 8399ef9eebSMatthias Springer 8499ef9eebSMatthias Springer // The offsets/sizes/strides attribute can have a less number of elements 8599ef9eebSMatthias Springer // than the input vector's rank: it is meant for the leading dimensions. 8699ef9eebSMatthias Springer auto newOffsets = rewriter.getArrayAttr( 877c38fd60SJacques Pienaar extractOp.getOffsets().getValue().drop_front(dropCount)); 8899ef9eebSMatthias Springer auto newSizes = rewriter.getArrayAttr( 897c38fd60SJacques Pienaar extractOp.getSizes().getValue().drop_front(dropCount)); 9099ef9eebSMatthias Springer auto newStrides = rewriter.getArrayAttr( 917c38fd60SJacques Pienaar extractOp.getStrides().getValue().drop_front(dropCount)); 9299ef9eebSMatthias Springer 9399ef9eebSMatthias Springer auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 9499ef9eebSMatthias Springer loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 9599ef9eebSMatthias Springer 9699ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 9799ef9eebSMatthias Springer newExtractOp); 9899ef9eebSMatthias Springer 9999ef9eebSMatthias Springer return success(); 10099ef9eebSMatthias Springer } 10199ef9eebSMatthias Springer }; 10299ef9eebSMatthias Springer 103e54236dfSLei Zhang // Casts away leading one dimensions in vector.insert_strided_slice's vector 104e54236dfSLei Zhang // inputs by inserting vector.broadcast. 10599ef9eebSMatthias Springer struct CastAwayInsertStridedSliceLeadingOneDim 10699ef9eebSMatthias Springer : public OpRewritePattern<vector::InsertStridedSliceOp> { 10799ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern; 10899ef9eebSMatthias Springer 10999ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 11099ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 11199ef9eebSMatthias Springer VectorType oldSrcType = insertOp.getSourceVectorType(); 11299ef9eebSMatthias Springer VectorType newSrcType = trimLeadingOneDims(oldSrcType); 11399ef9eebSMatthias Springer VectorType oldDstType = insertOp.getDestVectorType(); 11499ef9eebSMatthias Springer VectorType newDstType = trimLeadingOneDims(oldDstType); 11599ef9eebSMatthias Springer 11699ef9eebSMatthias Springer int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 11799ef9eebSMatthias Springer int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 11899ef9eebSMatthias Springer if (srcDropCount == 0 && dstDropCount == 0) 11999ef9eebSMatthias Springer return failure(); 12099ef9eebSMatthias Springer 12199ef9eebSMatthias Springer // Trim leading one dimensions from both operands. 12299ef9eebSMatthias Springer Location loc = insertOp.getLoc(); 12399ef9eebSMatthias Springer 12499ef9eebSMatthias Springer Value newSrcVector = rewriter.create<vector::ExtractOp>( 1257c38fd60SJacques Pienaar loc, insertOp.getSource(), splatZero(srcDropCount)); 12699ef9eebSMatthias Springer Value newDstVector = rewriter.create<vector::ExtractOp>( 1277c38fd60SJacques Pienaar loc, insertOp.getDest(), splatZero(dstDropCount)); 12899ef9eebSMatthias Springer 12999ef9eebSMatthias Springer auto newOffsets = rewriter.getArrayAttr( 1307c38fd60SJacques Pienaar insertOp.getOffsets().getValue().take_back(newDstType.getRank())); 13199ef9eebSMatthias Springer auto newStrides = rewriter.getArrayAttr( 1327c38fd60SJacques Pienaar insertOp.getStrides().getValue().take_back(newSrcType.getRank())); 13399ef9eebSMatthias Springer 13499ef9eebSMatthias Springer auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( 13599ef9eebSMatthias Springer loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); 13699ef9eebSMatthias Springer 13799ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 13899ef9eebSMatthias Springer newInsertOp); 13999ef9eebSMatthias Springer 14099ef9eebSMatthias Springer return success(); 14199ef9eebSMatthias Springer } 14299ef9eebSMatthias Springer }; 14399ef9eebSMatthias Springer 144e54236dfSLei Zhang // Casts away leading one dimensions in vector.insert's vector inputs by 145e54236dfSLei Zhang // inserting vector.broadcast. 146e54236dfSLei Zhang struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { 147e54236dfSLei Zhang using OpRewritePattern::OpRewritePattern; 148e54236dfSLei Zhang 149e54236dfSLei Zhang LogicalResult matchAndRewrite(vector::InsertOp insertOp, 150e54236dfSLei Zhang PatternRewriter &rewriter) const override { 151e54236dfSLei Zhang Type oldSrcType = insertOp.getSourceType(); 152e54236dfSLei Zhang Type newSrcType = oldSrcType; 153e54236dfSLei Zhang int64_t oldSrcRank = 0, newSrcRank = 0; 1545550c821STres Popp if (auto type = dyn_cast<VectorType>(oldSrcType)) { 155e54236dfSLei Zhang newSrcType = trimLeadingOneDims(type); 156e54236dfSLei Zhang oldSrcRank = type.getRank(); 1575550c821STres Popp newSrcRank = cast<VectorType>(newSrcType).getRank(); 158e54236dfSLei Zhang } 159e54236dfSLei Zhang 160e54236dfSLei Zhang VectorType oldDstType = insertOp.getDestVectorType(); 161e54236dfSLei Zhang VectorType newDstType = trimLeadingOneDims(oldDstType); 162e54236dfSLei Zhang 163e54236dfSLei Zhang int64_t srcDropCount = oldSrcRank - newSrcRank; 164e54236dfSLei Zhang int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 165e54236dfSLei Zhang if (srcDropCount == 0 && dstDropCount == 0) 166e54236dfSLei Zhang return failure(); 167e54236dfSLei Zhang 168e54236dfSLei Zhang // Trim leading one dimensions from both operands. 169e54236dfSLei Zhang Location loc = insertOp.getLoc(); 170e54236dfSLei Zhang 171e54236dfSLei Zhang Value newSrcVector = insertOp.getSource(); 172e54236dfSLei Zhang if (oldSrcRank != 0) { 173e54236dfSLei Zhang newSrcVector = rewriter.create<vector::ExtractOp>( 174e54236dfSLei Zhang loc, insertOp.getSource(), splatZero(srcDropCount)); 175e54236dfSLei Zhang } 176e54236dfSLei Zhang Value newDstVector = rewriter.create<vector::ExtractOp>( 177e54236dfSLei Zhang loc, insertOp.getDest(), splatZero(dstDropCount)); 178e54236dfSLei Zhang 179942b403fStyb0807 // New position rank needs to be computed in two steps: (1) if destination 180942b403fStyb0807 // type has leading unit dims, we also trim the position array accordingly, 181942b403fStyb0807 // then (2) if source type also has leading unit dims, we need to append 182942b403fStyb0807 // zeroes to the position array accordingly. 18398f6289aSDiego Caballero unsigned oldPosRank = insertOp.getNumIndices(); 184942b403fStyb0807 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount); 18598f6289aSDiego Caballero SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition(); 18698f6289aSDiego Caballero SmallVector<OpFoldResult> newPosition = 18798f6289aSDiego Caballero llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank)); 18898f6289aSDiego Caballero newPosition.resize(newDstType.getRank() - newSrcRank, 18998f6289aSDiego Caballero rewriter.getI64IntegerAttr(0)); 190e54236dfSLei Zhang 191e54236dfSLei Zhang auto newInsertOp = rewriter.create<vector::InsertOp>( 19298f6289aSDiego Caballero loc, newSrcVector, newDstVector, newPosition); 193e54236dfSLei Zhang 194e54236dfSLei Zhang rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 195e54236dfSLei Zhang newInsertOp); 196e54236dfSLei Zhang 197e54236dfSLei Zhang return success(); 198e54236dfSLei Zhang } 199e54236dfSLei Zhang }; 200e54236dfSLei Zhang 2016e8f7d59SQuinn Dawkins static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, 2026e8f7d59SQuinn Dawkins VectorType newType, AffineMap newMap, 2036e8f7d59SQuinn Dawkins VectorType oldMaskType) { 2046e8f7d59SQuinn Dawkins // Infer the type of the new mask from the new map. 2056e8f7d59SQuinn Dawkins VectorType newMaskType = inferTransferOpMaskType(newType, newMap); 2066e8f7d59SQuinn Dawkins 2076e8f7d59SQuinn Dawkins // If the new mask is broadcastable to the old result type, we can safely 2086e8f7d59SQuinn Dawkins // use a `vector.extract` to get the new mask. Otherwise the best we can 2096e8f7d59SQuinn Dawkins // do is shape cast. 2106e8f7d59SQuinn Dawkins if (vector::isBroadcastableTo(newMaskType, oldMaskType) == 2116e8f7d59SQuinn Dawkins BroadcastableToResult::Success) { 2126e8f7d59SQuinn Dawkins int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank(); 2136e8f7d59SQuinn Dawkins return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim)); 2146e8f7d59SQuinn Dawkins } 2156e8f7d59SQuinn Dawkins return b.create<vector::ShapeCastOp>(loc, newMaskType, mask); 2166e8f7d59SQuinn Dawkins } 2176e8f7d59SQuinn Dawkins 21899ef9eebSMatthias Springer // Turns vector.transfer_read on vector with leading 1 dimensions into 21999ef9eebSMatthias Springer // vector.shape_cast followed by vector.transfer_read on vector without leading 22099ef9eebSMatthias Springer // 1 dimensions. 22199ef9eebSMatthias Springer struct CastAwayTransferReadLeadingOneDim 22299ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferReadOp> { 22399ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern; 22499ef9eebSMatthias Springer 22599ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp read, 22699ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 227dedc7d4dSJerry Wu // TODO(#78787): Not supported masked op yet. 228dedc7d4dSJerry Wu if (cast<MaskableOpInterface>(read.getOperation()).isMasked()) 229dedc7d4dSJerry Wu return failure(); 23099ef9eebSMatthias Springer // TODO: support 0-d corner case. 23199ef9eebSMatthias Springer if (read.getTransferRank() == 0) 23299ef9eebSMatthias Springer return failure(); 23399ef9eebSMatthias Springer 2345550c821STres Popp auto shapedType = cast<ShapedType>(read.getSource().getType()); 23599ef9eebSMatthias Springer if (shapedType.getElementType() != read.getVectorType().getElementType()) 23699ef9eebSMatthias Springer return failure(); 23799ef9eebSMatthias Springer 23899ef9eebSMatthias Springer VectorType oldType = read.getVectorType(); 23999ef9eebSMatthias Springer VectorType newType = trimLeadingOneDims(oldType); 24099ef9eebSMatthias Springer 24199ef9eebSMatthias Springer if (newType == oldType) 24299ef9eebSMatthias Springer return failure(); 24399ef9eebSMatthias Springer 2447c38fd60SJacques Pienaar AffineMap oldMap = read.getPermutationMap(); 24599ef9eebSMatthias Springer ArrayRef<AffineExpr> newResults = 24699ef9eebSMatthias Springer oldMap.getResults().take_back(newType.getRank()); 24799ef9eebSMatthias Springer AffineMap newMap = 24899ef9eebSMatthias Springer AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 24999ef9eebSMatthias Springer rewriter.getContext()); 25099ef9eebSMatthias Springer 25199ef9eebSMatthias Springer ArrayAttr inBoundsAttr; 2527c38fd60SJacques Pienaar if (read.getInBounds()) 25399ef9eebSMatthias Springer inBoundsAttr = rewriter.getArrayAttr( 2547c38fd60SJacques Pienaar read.getInBoundsAttr().getValue().take_back(newType.getRank())); 25599ef9eebSMatthias Springer 256796d48b0SQuinn Dawkins Value mask = Value(); 257796d48b0SQuinn Dawkins if (read.getMask()) { 2586e8f7d59SQuinn Dawkins VectorType maskType = read.getMaskType(); 2596e8f7d59SQuinn Dawkins mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(), 2606e8f7d59SQuinn Dawkins newType, newMap, maskType); 261796d48b0SQuinn Dawkins } 262796d48b0SQuinn Dawkins 26399ef9eebSMatthias Springer auto newRead = rewriter.create<vector::TransferReadOp>( 2647c38fd60SJacques Pienaar read.getLoc(), newType, read.getSource(), read.getIndices(), 265796d48b0SQuinn Dawkins AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr); 26699ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); 26799ef9eebSMatthias Springer 26899ef9eebSMatthias Springer return success(); 26999ef9eebSMatthias Springer } 27099ef9eebSMatthias Springer }; 27199ef9eebSMatthias Springer 27299ef9eebSMatthias Springer // Turns vector.transfer_write on vector with leading 1 dimensions into 27399ef9eebSMatthias Springer // vector.shape_cast followed by vector.transfer_write on vector without leading 27499ef9eebSMatthias Springer // 1 dimensions. 27599ef9eebSMatthias Springer struct CastAwayTransferWriteLeadingOneDim 27699ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferWriteOp> { 27799ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern; 27899ef9eebSMatthias Springer 27999ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferWriteOp write, 28099ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 281dedc7d4dSJerry Wu // TODO(#78787): Not supported masked op yet. 282dedc7d4dSJerry Wu if (cast<MaskableOpInterface>(write.getOperation()).isMasked()) 283dedc7d4dSJerry Wu return failure(); 28499ef9eebSMatthias Springer // TODO: support 0-d corner case. 28599ef9eebSMatthias Springer if (write.getTransferRank() == 0) 28699ef9eebSMatthias Springer return failure(); 28799ef9eebSMatthias Springer 2885550c821STres Popp auto shapedType = dyn_cast<ShapedType>(write.getSource().getType()); 28999ef9eebSMatthias Springer if (shapedType.getElementType() != write.getVectorType().getElementType()) 29099ef9eebSMatthias Springer return failure(); 29199ef9eebSMatthias Springer 29299ef9eebSMatthias Springer VectorType oldType = write.getVectorType(); 29399ef9eebSMatthias Springer VectorType newType = trimLeadingOneDims(oldType); 29499ef9eebSMatthias Springer if (newType == oldType) 29599ef9eebSMatthias Springer return failure(); 29699ef9eebSMatthias Springer int64_t dropDim = oldType.getRank() - newType.getRank(); 29799ef9eebSMatthias Springer 2987c38fd60SJacques Pienaar AffineMap oldMap = write.getPermutationMap(); 29999ef9eebSMatthias Springer ArrayRef<AffineExpr> newResults = 30099ef9eebSMatthias Springer oldMap.getResults().take_back(newType.getRank()); 30199ef9eebSMatthias Springer AffineMap newMap = 30299ef9eebSMatthias Springer AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 30399ef9eebSMatthias Springer rewriter.getContext()); 30499ef9eebSMatthias Springer 30599ef9eebSMatthias Springer ArrayAttr inBoundsAttr; 3067c38fd60SJacques Pienaar if (write.getInBounds()) 30799ef9eebSMatthias Springer inBoundsAttr = rewriter.getArrayAttr( 3087c38fd60SJacques Pienaar write.getInBoundsAttr().getValue().take_back(newType.getRank())); 30999ef9eebSMatthias Springer 31099ef9eebSMatthias Springer auto newVector = rewriter.create<vector::ExtractOp>( 3117c38fd60SJacques Pienaar write.getLoc(), write.getVector(), splatZero(dropDim)); 312796d48b0SQuinn Dawkins 313796d48b0SQuinn Dawkins if (write.getMask()) { 3146e8f7d59SQuinn Dawkins VectorType maskType = write.getMaskType(); 3156e8f7d59SQuinn Dawkins Value newMask = dropUnitDimsFromMask( 3166e8f7d59SQuinn Dawkins rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType); 317796d48b0SQuinn Dawkins rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 318796d48b0SQuinn Dawkins write, newVector, write.getSource(), write.getIndices(), 319796d48b0SQuinn Dawkins AffineMapAttr::get(newMap), newMask, inBoundsAttr); 320796d48b0SQuinn Dawkins return success(); 321796d48b0SQuinn Dawkins } 322796d48b0SQuinn Dawkins 32399ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 3247c38fd60SJacques Pienaar write, newVector, write.getSource(), write.getIndices(), 32599ef9eebSMatthias Springer AffineMapAttr::get(newMap), inBoundsAttr); 32699ef9eebSMatthias Springer return success(); 32799ef9eebSMatthias Springer } 32899ef9eebSMatthias Springer }; 32999ef9eebSMatthias Springer 330eca7698aSLei Zhang } // namespace 331ad9b5a4bSNirvedh 3325f1b2cffSAndrzej Warzyński FailureOr<Value> 333eca7698aSLei Zhang mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, 3345f1b2cffSAndrzej Warzyński MaskingOpInterface maskingOp, 335eca7698aSLei Zhang RewriterBase &rewriter) { 3365550c821STres Popp VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType()); 337ad9b5a4bSNirvedh if (oldAccType == nullptr) 338ad9b5a4bSNirvedh return failure(); 339ad9b5a4bSNirvedh if (oldAccType.getRank() < 2) 340ad9b5a4bSNirvedh return failure(); 341ad9b5a4bSNirvedh if (oldAccType.getShape()[0] != 1) 342ad9b5a4bSNirvedh return failure(); 343ad9b5a4bSNirvedh // currently we support only dropping one dim but the pattern can be applied 344ad9b5a4bSNirvedh // greedily to drop more. 345ad9b5a4bSNirvedh int64_t dropDim = 1; 346ad9b5a4bSNirvedh 347d2c0572bSJacques Pienaar auto oldIndexingMaps = contractOp.getIndexingMapsArray(); 348ad9b5a4bSNirvedh SmallVector<AffineMap> newIndexingMaps; 349ad9b5a4bSNirvedh 3507c38fd60SJacques Pienaar auto oldIteratorTypes = contractOp.getIteratorTypes(); 351ad9b5a4bSNirvedh SmallVector<Attribute> newIteratorTypes; 352ad9b5a4bSNirvedh 353ad9b5a4bSNirvedh int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); 354ad9b5a4bSNirvedh 355ad9b5a4bSNirvedh if (!isParallelIterator(oldIteratorTypes[dimToDrop])) 356ad9b5a4bSNirvedh // only parallel type iterators can be dropped. 357ad9b5a4bSNirvedh return failure(); 358ad9b5a4bSNirvedh 359ad9b5a4bSNirvedh for (const auto &it : llvm::enumerate(oldIteratorTypes)) { 360ad9b5a4bSNirvedh int64_t currDim = it.index(); 361ad9b5a4bSNirvedh if (currDim == dimToDrop) 362ad9b5a4bSNirvedh continue; 363ad9b5a4bSNirvedh newIteratorTypes.push_back(it.value()); 364ad9b5a4bSNirvedh } 365ad9b5a4bSNirvedh 3667c38fd60SJacques Pienaar SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(), 3677c38fd60SJacques Pienaar contractOp.getAcc()}; 368ad9b5a4bSNirvedh SmallVector<Value> newOperands; 3695f1b2cffSAndrzej Warzyński auto loc = contractOp.getLoc(); 370ad9b5a4bSNirvedh 371ad9b5a4bSNirvedh for (const auto &it : llvm::enumerate(oldIndexingMaps)) { 372ad9b5a4bSNirvedh // Check if the dim to be dropped exists as a leading dim in the operand 373ad9b5a4bSNirvedh // if it does then we use vector.extract to drop it. 374ad9b5a4bSNirvedh bool validExtract = false; 375ad9b5a4bSNirvedh SmallVector<AffineExpr> results; 376ad9b5a4bSNirvedh auto map = it.value(); 377ad9b5a4bSNirvedh int64_t orginalZeroDim = it.value().getDimPosition(0); 378ad9b5a4bSNirvedh if (orginalZeroDim != dimToDrop) { 379ad9b5a4bSNirvedh // There are two reasons to be in this path, 1. We need to 380*aa295216SJay Foad // transpose the operand to make the dim to be dropped 381ad9b5a4bSNirvedh // leading. 2. The dim to be dropped does not exist and in 382*aa295216SJay Foad // that case we dont want to add a unit transpose but we must 383ad9b5a4bSNirvedh // check all the indices to make sure this is the case. 384*aa295216SJay Foad bool transposeNeeded = false; 385ad9b5a4bSNirvedh SmallVector<int64_t> perm; 386ad9b5a4bSNirvedh SmallVector<AffineExpr> transposeResults; 387ad9b5a4bSNirvedh 388ad9b5a4bSNirvedh for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 389ad9b5a4bSNirvedh int64_t currDim = map.getDimPosition(i); 390ad9b5a4bSNirvedh if (currDim == dimToDrop) { 391*aa295216SJay Foad transposeNeeded = true; 392ad9b5a4bSNirvedh perm.insert(perm.begin(), i); 393ad9b5a4bSNirvedh auto targetExpr = rewriter.getAffineDimExpr(currDim); 394ad9b5a4bSNirvedh transposeResults.insert(transposeResults.begin(), targetExpr); 395ad9b5a4bSNirvedh } else { 396ad9b5a4bSNirvedh perm.push_back(i); 397ad9b5a4bSNirvedh auto targetExpr = rewriter.getAffineDimExpr(currDim); 398ad9b5a4bSNirvedh transposeResults.push_back(targetExpr); 399ad9b5a4bSNirvedh } 400ad9b5a4bSNirvedh } 40166fed33dSKojo Acquah 40266fed33dSKojo Acquah // Checks if only the outer, unit dimensions (of size 1) are permuted. 40366fed33dSKojo Acquah // Such transposes do not materially effect the underlying vector and can 40466fed33dSKojo Acquah // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32> 40566fed33dSKojo Acquah bool transposeNonOuterUnitDims = false; 406a5757c5bSChristian Sigg auto operandShape = cast<ShapedType>(operands[it.index()].getType()); 40766fed33dSKojo Acquah for (auto [index, dim] : 40866fed33dSKojo Acquah llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) { 40966fed33dSKojo Acquah if (dim != static_cast<int64_t>(index) && 41066fed33dSKojo Acquah operandShape.getDimSize(index) != 1) { 41166fed33dSKojo Acquah transposeNonOuterUnitDims = true; 41266fed33dSKojo Acquah break; 41366fed33dSKojo Acquah } 41466fed33dSKojo Acquah } 41566fed33dSKojo Acquah 416*aa295216SJay Foad // Do the transpose now if needed so that we can drop the 417ad9b5a4bSNirvedh // correct dim using extract later. 418*aa295216SJay Foad if (transposeNeeded) { 419ad9b5a4bSNirvedh map = AffineMap::get(map.getNumDims(), 0, transposeResults, 420ad9b5a4bSNirvedh contractOp.getContext()); 42166fed33dSKojo Acquah if (transposeNonOuterUnitDims) { 42266fed33dSKojo Acquah operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>( 4235f1b2cffSAndrzej Warzyński loc, operands[it.index()], perm); 424ad9b5a4bSNirvedh } 425ad9b5a4bSNirvedh } 42666fed33dSKojo Acquah } 427ad9b5a4bSNirvedh // We have taken care to have the dim to be dropped be 428ad9b5a4bSNirvedh // the leading dim. If its still not leading that means it 429ad9b5a4bSNirvedh // does not exist in this operand and hence we do not need 430ad9b5a4bSNirvedh // an extract. 431ad9b5a4bSNirvedh if (map.getDimPosition(0) == dimToDrop) 432ad9b5a4bSNirvedh validExtract = true; 433ad9b5a4bSNirvedh 434ad9b5a4bSNirvedh for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 435ad9b5a4bSNirvedh int64_t currDim = map.getDimPosition(i); 436ad9b5a4bSNirvedh if (currDim == dimToDrop) 437ad9b5a4bSNirvedh // This is the dim we are dropping. 438ad9b5a4bSNirvedh continue; 439ad9b5a4bSNirvedh auto targetExpr = rewriter.getAffineDimExpr( 440ad9b5a4bSNirvedh currDim < dimToDrop ? currDim : currDim - 1); 441ad9b5a4bSNirvedh results.push_back(targetExpr); 442ad9b5a4bSNirvedh } 443ad9b5a4bSNirvedh newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, 444ad9b5a4bSNirvedh contractOp.getContext())); 445ad9b5a4bSNirvedh // Extract if its a valid extraction, otherwise use the operand 446ad9b5a4bSNirvedh // without extraction. 447eca7698aSLei Zhang newOperands.push_back( 4485f1b2cffSAndrzej Warzyński validExtract ? rewriter.create<vector::ExtractOp>( 4495f1b2cffSAndrzej Warzyński loc, operands[it.index()], splatZero(dropDim)) 450ad9b5a4bSNirvedh : operands[it.index()]); 451ad9b5a4bSNirvedh } 4525f1b2cffSAndrzej Warzyński 4535f1b2cffSAndrzej Warzyński // Depending on whether this vector.contract is masked, the replacing Op 4545f1b2cffSAndrzej Warzyński // should either be a new vector.contract Op or vector.mask Op. 4555f1b2cffSAndrzej Warzyński Operation *newOp = rewriter.create<vector::ContractionOp>( 4565f1b2cffSAndrzej Warzyński loc, newOperands[0], newOperands[1], newOperands[2], 457ad9b5a4bSNirvedh rewriter.getAffineMapArrayAttr(newIndexingMaps), 4587c38fd60SJacques Pienaar rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); 4595f1b2cffSAndrzej Warzyński 4605f1b2cffSAndrzej Warzyński if (maskingOp) { 4615f1b2cffSAndrzej Warzyński auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(), 4625f1b2cffSAndrzej Warzyński splatZero(dropDim)); 4635f1b2cffSAndrzej Warzyński 4645f1b2cffSAndrzej Warzyński newOp = mlir::vector::maskOperation(rewriter, newOp, newMask); 4655f1b2cffSAndrzej Warzyński } 4665f1b2cffSAndrzej Warzyński 4675f1b2cffSAndrzej Warzyński return rewriter 4685f1b2cffSAndrzej Warzyński .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0], 4695f1b2cffSAndrzej Warzyński newOp->getResults()[0]) 4705f1b2cffSAndrzej Warzyński .getResult(); 471ad9b5a4bSNirvedh } 472eca7698aSLei Zhang 473eca7698aSLei Zhang namespace { 474eca7698aSLei Zhang 475eca7698aSLei Zhang /// Turns vector.contract on vector with leading 1 dimensions into 476eca7698aSLei Zhang /// vector.extract followed by vector.contract on vector without leading 477*aa295216SJay Foad /// 1 dimensions. Also performs transpose of lhs and rhs operands if required 478eca7698aSLei Zhang /// prior to extract. 479eca7698aSLei Zhang struct CastAwayContractionLeadingOneDim 4805f1b2cffSAndrzej Warzyński : public MaskableOpRewritePattern<vector::ContractionOp> { 4815f1b2cffSAndrzej Warzyński using MaskableOpRewritePattern::MaskableOpRewritePattern; 482eca7698aSLei Zhang 4835f1b2cffSAndrzej Warzyński FailureOr<Value> 4845f1b2cffSAndrzej Warzyński matchAndRewriteMaskableOp(vector::ContractionOp contractOp, 4855f1b2cffSAndrzej Warzyński MaskingOpInterface maskingOp, 486eca7698aSLei Zhang PatternRewriter &rewriter) const override { 4875f1b2cffSAndrzej Warzyński return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter); 488eca7698aSLei Zhang } 489ad9b5a4bSNirvedh }; 490ad9b5a4bSNirvedh 491f9070b2dSAndrzej Warzynski /// Looks at elementwise operations on vectors with at least one leading 492f9070b2dSAndrzej Warzynski /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>), 493f9070b2dSAndrzej Warzynski /// and cast aways the leading one dimensions (_plural_) and then broadcasts 494f9070b2dSAndrzej Warzynski /// the results. 495f9070b2dSAndrzej Warzynski /// 496f9070b2dSAndrzej Warzynski /// Example before: 497f9070b2dSAndrzej Warzynski /// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> 498f9070b2dSAndrzej Warzynski /// Example after: 499f9070b2dSAndrzej Warzynski /// %2 = arith.mulf %0, %1 : vector<4x1xf32> 500f9070b2dSAndrzej Warzynski /// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32> 501f9070b2dSAndrzej Warzynski /// 502f9070b2dSAndrzej Warzynski /// Does support scalable vectors. 50399ef9eebSMatthias Springer class CastAwayElementwiseLeadingOneDim : public RewritePattern { 50499ef9eebSMatthias Springer public: 50527cc31b6SNicolas Vasilache CastAwayElementwiseLeadingOneDim(MLIRContext *context, 50627cc31b6SNicolas Vasilache PatternBenefit benefit = 1) 50727cc31b6SNicolas Vasilache : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} 50899ef9eebSMatthias Springer 50999ef9eebSMatthias Springer LogicalResult matchAndRewrite(Operation *op, 51099ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 51199ef9eebSMatthias Springer if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) 51299ef9eebSMatthias Springer return failure(); 5135550c821STres Popp auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]); 51499ef9eebSMatthias Springer if (!vecType) 51599ef9eebSMatthias Springer return failure(); 51699ef9eebSMatthias Springer VectorType newVecType = trimLeadingOneDims(vecType); 51799ef9eebSMatthias Springer if (newVecType == vecType) 51899ef9eebSMatthias Springer return failure(); 51999ef9eebSMatthias Springer int64_t dropDim = vecType.getRank() - newVecType.getRank(); 52099ef9eebSMatthias Springer SmallVector<Value, 4> newOperands; 52199ef9eebSMatthias Springer for (Value operand : op->getOperands()) { 5225550c821STres Popp if (auto opVecType = dyn_cast<VectorType>(operand.getType())) { 52399ef9eebSMatthias Springer newOperands.push_back(rewriter.create<vector::ExtractOp>( 52499ef9eebSMatthias Springer op->getLoc(), operand, splatZero(dropDim))); 52599ef9eebSMatthias Springer } else { 52699ef9eebSMatthias Springer newOperands.push_back(operand); 52799ef9eebSMatthias Springer } 52899ef9eebSMatthias Springer } 52914ecafd0SChia-hung Duan Operation *newOp = 53014ecafd0SChia-hung Duan rewriter.create(op->getLoc(), op->getName().getIdentifier(), 53114ecafd0SChia-hung Duan newOperands, newVecType, op->getAttrs()); 53299ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType, 53399ef9eebSMatthias Springer newOp->getResult(0)); 53499ef9eebSMatthias Springer return success(); 53599ef9eebSMatthias Springer } 53699ef9eebSMatthias Springer }; 53799ef9eebSMatthias Springer 538796d48b0SQuinn Dawkins // Drops leading 1 dimensions from vector.constant_mask and inserts a 539796d48b0SQuinn Dawkins // vector.broadcast back to the original shape. 540796d48b0SQuinn Dawkins struct CastAwayConstantMaskLeadingOneDim 541796d48b0SQuinn Dawkins : public OpRewritePattern<vector::ConstantMaskOp> { 542796d48b0SQuinn Dawkins using OpRewritePattern::OpRewritePattern; 543796d48b0SQuinn Dawkins 544796d48b0SQuinn Dawkins LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, 545796d48b0SQuinn Dawkins PatternRewriter &rewriter) const override { 546796d48b0SQuinn Dawkins VectorType oldType = mask.getType(); 547796d48b0SQuinn Dawkins VectorType newType = trimLeadingOneDims(oldType); 548796d48b0SQuinn Dawkins 549796d48b0SQuinn Dawkins if (newType == oldType) 550796d48b0SQuinn Dawkins return failure(); 551796d48b0SQuinn Dawkins 552796d48b0SQuinn Dawkins int64_t dropDim = oldType.getRank() - newType.getRank(); 5530d9b4394SBenjamin Maxwell ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes(); 554796d48b0SQuinn Dawkins 555796d48b0SQuinn Dawkins // If any of the dropped unit dims has a size of `0`, the entire mask is a 556796d48b0SQuinn Dawkins // zero mask, else the unit dim has no effect on the mask. 557796d48b0SQuinn Dawkins int64_t flatLeadingSize = 558796d48b0SQuinn Dawkins std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1, 559796d48b0SQuinn Dawkins static_cast<int64_t>(1), std::multiplies<int64_t>()); 5609cbc1f29SHan-Chung Wang SmallVector<int64_t> newDimSizes = {flatLeadingSize}; 561796d48b0SQuinn Dawkins newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); 562796d48b0SQuinn Dawkins 563796d48b0SQuinn Dawkins auto newMask = rewriter.create<vector::ConstantMaskOp>( 5640d9b4394SBenjamin Maxwell mask.getLoc(), newType, newDimSizes); 565796d48b0SQuinn Dawkins rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask); 566796d48b0SQuinn Dawkins return success(); 567796d48b0SQuinn Dawkins } 568796d48b0SQuinn Dawkins }; 569796d48b0SQuinn Dawkins 57099ef9eebSMatthias Springer } // namespace 57199ef9eebSMatthias Springer 57299ef9eebSMatthias Springer void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( 57327cc31b6SNicolas Vasilache RewritePatternSet &patterns, PatternBenefit benefit) { 574ad9b5a4bSNirvedh patterns 575ad9b5a4bSNirvedh .add<CastAwayExtractStridedSliceLeadingOneDim, 576e54236dfSLei Zhang CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim, 577796d48b0SQuinn Dawkins CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim, 578ad9b5a4bSNirvedh CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim, 57927cc31b6SNicolas Vasilache CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit); 58027cc31b6SNicolas Vasilache populateShapeCastFoldingPatterns(patterns, benefit); 58199ef9eebSMatthias Springer } 582