xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (revision 099fd018d1b04013ef46c0e26ed008585ab8bcbb)
12bc4c3e9SNicolas Vasilache //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
22bc4c3e9SNicolas Vasilache //
32bc4c3e9SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42bc4c3e9SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
52bc4c3e9SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62bc4c3e9SNicolas Vasilache //
72bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===//
82bc4c3e9SNicolas Vasilache //
92bc4c3e9SNicolas Vasilache // This file implements rewrite patterns for the permutation_map attribute of
102bc4c3e9SNicolas Vasilache // vector.transfer operations.
112bc4c3e9SNicolas Vasilache //
122bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===//
132bc4c3e9SNicolas Vasilache 
142bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
152bc4c3e9SNicolas Vasilache #include "mlir/Dialect/MemRef/IR/MemRef.h"
162bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Tensor/IR/Tensor.h"
172bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
182bc4c3e9SNicolas Vasilache #include "mlir/Interfaces/VectorInterfaces.h"
192bc4c3e9SNicolas Vasilache 
202bc4c3e9SNicolas Vasilache using namespace mlir;
212bc4c3e9SNicolas Vasilache using namespace mlir::vector;
222bc4c3e9SNicolas Vasilache 
232bc4c3e9SNicolas Vasilache /// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
242bc4c3e9SNicolas Vasilache /// permutation based on the given indices.
252bc4c3e9SNicolas Vasilache static ArrayAttr
262bc4c3e9SNicolas Vasilache inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
272bc4c3e9SNicolas Vasilache                              const SmallVector<unsigned> &permutation) {
282bc4c3e9SNicolas Vasilache   SmallVector<bool> newInBoundsValues(permutation.size());
292bc4c3e9SNicolas Vasilache   size_t index = 0;
302bc4c3e9SNicolas Vasilache   for (unsigned pos : permutation)
312bc4c3e9SNicolas Vasilache     newInBoundsValues[pos] =
325550c821STres Popp         cast<BoolAttr>(attr.getValue()[index++]).getValue();
332bc4c3e9SNicolas Vasilache   return builder.getBoolArrayAttr(newInBoundsValues);
342bc4c3e9SNicolas Vasilache }
352bc4c3e9SNicolas Vasilache 
362bc4c3e9SNicolas Vasilache /// Extend the rank of a vector Value by `addedRanks` by adding outer unit
372bc4c3e9SNicolas Vasilache /// dimensions.
382bc4c3e9SNicolas Vasilache static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
392bc4c3e9SNicolas Vasilache                               int64_t addedRank) {
405550c821STres Popp   auto originalVecType = cast<VectorType>(vec.getType());
412bc4c3e9SNicolas Vasilache   SmallVector<int64_t> newShape(addedRank, 1);
422bc4c3e9SNicolas Vasilache   newShape.append(originalVecType.getShape().begin(),
432bc4c3e9SNicolas Vasilache                   originalVecType.getShape().end());
44465ea0bfSCrefeda Rodrigues 
45465ea0bfSCrefeda Rodrigues   SmallVector<bool> newScalableDims(addedRank, false);
46465ea0bfSCrefeda Rodrigues   newScalableDims.append(originalVecType.getScalableDims().begin(),
47465ea0bfSCrefeda Rodrigues                          originalVecType.getScalableDims().end());
48465ea0bfSCrefeda Rodrigues   VectorType newVecType = VectorType::get(
49465ea0bfSCrefeda Rodrigues       newShape, originalVecType.getElementType(), newScalableDims);
502bc4c3e9SNicolas Vasilache   return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
512bc4c3e9SNicolas Vasilache }
522bc4c3e9SNicolas Vasilache 
53a7a5641bSMatthias Springer /// Extend the rank of a vector Value by `addedRanks` by adding inner unit
54a7a5641bSMatthias Springer /// dimensions.
55a7a5641bSMatthias Springer static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
56a7a5641bSMatthias Springer                             int64_t addedRank) {
57a7a5641bSMatthias Springer   Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
58a7a5641bSMatthias Springer   SmallVector<int64_t> permutation;
59a7a5641bSMatthias Springer   for (int64_t i = addedRank,
60a5757c5bSChristian Sigg                e = cast<VectorType>(broadcasted.getType()).getRank();
61a7a5641bSMatthias Springer        i < e; ++i)
62a7a5641bSMatthias Springer     permutation.push_back(i);
63a7a5641bSMatthias Springer   for (int64_t i = 0; i < addedRank; ++i)
64a7a5641bSMatthias Springer     permutation.push_back(i);
65a7a5641bSMatthias Springer   return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
66a7a5641bSMatthias Springer }
67a7a5641bSMatthias Springer 
682bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===//
692bc4c3e9SNicolas Vasilache // populateVectorTransferPermutationMapLoweringPatterns
702bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===//
712bc4c3e9SNicolas Vasilache 
722bc4c3e9SNicolas Vasilache namespace {
732bc4c3e9SNicolas Vasilache /// Lower transfer_read op with permutation into a transfer_read with a
742bc4c3e9SNicolas Vasilache /// permutation map composed of leading zeros followed by a minor identiy +
752bc4c3e9SNicolas Vasilache /// vector.transpose op.
762bc4c3e9SNicolas Vasilache /// Ex:
772bc4c3e9SNicolas Vasilache ///     vector.transfer_read ...
782bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2) -> (0, d1)
792bc4c3e9SNicolas Vasilache /// into:
802bc4c3e9SNicolas Vasilache ///     %v = vector.transfer_read ...
812bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2) -> (d1, 0)
822bc4c3e9SNicolas Vasilache ///     vector.transpose %v, [1, 0]
832bc4c3e9SNicolas Vasilache ///
842bc4c3e9SNicolas Vasilache ///     vector.transfer_read ...
852bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
862bc4c3e9SNicolas Vasilache /// into:
872bc4c3e9SNicolas Vasilache ///     %v = vector.transfer_read ...
882bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
892bc4c3e9SNicolas Vasilache ///     vector.transpose %v, [0, 1, 3, 2, 4]
902bc4c3e9SNicolas Vasilache /// Note that an alternative is to transform it to linalg.transpose +
912bc4c3e9SNicolas Vasilache /// vector.transfer_read to do the transpose in memory instead.
922bc4c3e9SNicolas Vasilache struct TransferReadPermutationLowering
93fdd245adSHugo Trachino     : public MaskableOpRewritePattern<vector::TransferReadOp> {
94fdd245adSHugo Trachino   using MaskableOpRewritePattern::MaskableOpRewritePattern;
952bc4c3e9SNicolas Vasilache 
96fdd245adSHugo Trachino   FailureOr<mlir::Value>
97fdd245adSHugo Trachino   matchAndRewriteMaskableOp(vector::TransferReadOp op,
98fdd245adSHugo Trachino                             MaskingOpInterface maskOp,
992bc4c3e9SNicolas Vasilache                             PatternRewriter &rewriter) const override {
1002bc4c3e9SNicolas Vasilache     // TODO: support 0-d corner case.
1012bc4c3e9SNicolas Vasilache     if (op.getTransferRank() == 0)
1028b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
103fdd245adSHugo Trachino     // TODO: Support transfer_read inside MaskOp case.
104fdd245adSHugo Trachino     if (maskOp)
105fdd245adSHugo Trachino       return rewriter.notifyMatchFailure(op, "Masked case not supported");
1062bc4c3e9SNicolas Vasilache 
1072bc4c3e9SNicolas Vasilache     SmallVector<unsigned> permutation;
1082bc4c3e9SNicolas Vasilache     AffineMap map = op.getPermutationMap();
1092bc4c3e9SNicolas Vasilache     if (map.getNumResults() == 0)
1108b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(op, "0 result permutation map");
1118b513407SNicolas Vasilache     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
1128b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(
1138b513407SNicolas Vasilache           op, "map is not permutable to minor identity, apply another pattern");
1148b513407SNicolas Vasilache     }
1152bc4c3e9SNicolas Vasilache     AffineMap permutationMap =
1162bc4c3e9SNicolas Vasilache         map.getPermutationMap(permutation, op.getContext());
1172bc4c3e9SNicolas Vasilache     if (permutationMap.isIdentity())
1188b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(op, "map is not identity");
1192bc4c3e9SNicolas Vasilache 
1202bc4c3e9SNicolas Vasilache     permutationMap = map.getPermutationMap(permutation, op.getContext());
1212bc4c3e9SNicolas Vasilache     // Caluclate the map of the new read by applying the inverse permutation.
1222bc4c3e9SNicolas Vasilache     permutationMap = inversePermutation(permutationMap);
1232bc4c3e9SNicolas Vasilache     AffineMap newMap = permutationMap.compose(map);
1242bc4c3e9SNicolas Vasilache     // Apply the reverse transpose to deduce the type of the transfer_read.
1252bc4c3e9SNicolas Vasilache     ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
1262bc4c3e9SNicolas Vasilache     SmallVector<int64_t> newVectorShape(originalShape.size());
12712b49518SAndrzej Warzynski     ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
12812b49518SAndrzej Warzynski     SmallVector<bool> newScalableDims(originalShape.size());
1292bc4c3e9SNicolas Vasilache     for (const auto &pos : llvm::enumerate(permutation)) {
1302bc4c3e9SNicolas Vasilache       newVectorShape[pos.value()] = originalShape[pos.index()];
13112b49518SAndrzej Warzynski       newScalableDims[pos.value()] = originalScalableDims[pos.index()];
1322bc4c3e9SNicolas Vasilache     }
1332bc4c3e9SNicolas Vasilache 
1342bc4c3e9SNicolas Vasilache     // Transpose in_bounds attribute.
1352bc4c3e9SNicolas Vasilache     ArrayAttr newInBoundsAttr =
1362ee5586aSAndrzej Warzyński         inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
1372bc4c3e9SNicolas Vasilache 
1382bc4c3e9SNicolas Vasilache     // Generate new transfer_read operation.
13912b49518SAndrzej Warzynski     VectorType newReadType = VectorType::get(
14012b49518SAndrzej Warzynski         newVectorShape, op.getVectorType().getElementType(), newScalableDims);
1412bc4c3e9SNicolas Vasilache     Value newRead = rewriter.create<vector::TransferReadOp>(
1422bc4c3e9SNicolas Vasilache         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
1432bc4c3e9SNicolas Vasilache         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
1442bc4c3e9SNicolas Vasilache         newInBoundsAttr);
1452bc4c3e9SNicolas Vasilache 
1462bc4c3e9SNicolas Vasilache     // Transpose result of transfer_read.
1472bc4c3e9SNicolas Vasilache     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
148fdd245adSHugo Trachino     return rewriter
149fdd245adSHugo Trachino         .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
150fdd245adSHugo Trachino         .getResult();
1512bc4c3e9SNicolas Vasilache   }
1522bc4c3e9SNicolas Vasilache };
1532bc4c3e9SNicolas Vasilache 
1542bc4c3e9SNicolas Vasilache /// Lower transfer_write op with permutation into a transfer_write with a
1552bc4c3e9SNicolas Vasilache /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
1562bc4c3e9SNicolas Vasilache /// Ex:
1572bc4c3e9SNicolas Vasilache ///     vector.transfer_write %v ...
1582bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
1592bc4c3e9SNicolas Vasilache /// into:
1602bc4c3e9SNicolas Vasilache ///     %tmp = vector.transpose %v, [2, 0, 1]
1612bc4c3e9SNicolas Vasilache ///     vector.transfer_write %tmp ...
1622bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
1632bc4c3e9SNicolas Vasilache ///
1642bc4c3e9SNicolas Vasilache ///     vector.transfer_write %v ...
1652bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
1662bc4c3e9SNicolas Vasilache /// into:
1672bc4c3e9SNicolas Vasilache ///     %tmp = vector.transpose %v, [1, 0]
1682bc4c3e9SNicolas Vasilache ///     %v = vector.transfer_write %tmp ...
1692bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
1702bc4c3e9SNicolas Vasilache struct TransferWritePermutationLowering
171fdd245adSHugo Trachino     : public MaskableOpRewritePattern<vector::TransferWriteOp> {
172fdd245adSHugo Trachino   using MaskableOpRewritePattern::MaskableOpRewritePattern;
1732bc4c3e9SNicolas Vasilache 
174fdd245adSHugo Trachino   FailureOr<mlir::Value>
175fdd245adSHugo Trachino   matchAndRewriteMaskableOp(vector::TransferWriteOp op,
176fdd245adSHugo Trachino                             MaskingOpInterface maskOp,
1772bc4c3e9SNicolas Vasilache                             PatternRewriter &rewriter) const override {
1782bc4c3e9SNicolas Vasilache     // TODO: support 0-d corner case.
1792bc4c3e9SNicolas Vasilache     if (op.getTransferRank() == 0)
1808b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
181fdd245adSHugo Trachino     // TODO: Support transfer_write inside MaskOp case.
182fdd245adSHugo Trachino     if (maskOp)
183fdd245adSHugo Trachino       return rewriter.notifyMatchFailure(op, "Masked case not supported");
1842bc4c3e9SNicolas Vasilache 
1852bc4c3e9SNicolas Vasilache     SmallVector<unsigned> permutation;
1862bc4c3e9SNicolas Vasilache     AffineMap map = op.getPermutationMap();
1872bc4c3e9SNicolas Vasilache     if (map.isMinorIdentity())
1888b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(op, "map is already minor identity");
1898b513407SNicolas Vasilache 
1908b513407SNicolas Vasilache     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
1918b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(
1928b513407SNicolas Vasilache           op, "map is not permutable to minor identity, apply another pattern");
1938b513407SNicolas Vasilache     }
1942bc4c3e9SNicolas Vasilache 
1952bc4c3e9SNicolas Vasilache     // Remove unused dims from the permutation map. E.g.:
1962bc4c3e9SNicolas Vasilache     // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
1972bc4c3e9SNicolas Vasilache     // comp = (d0, d1, d2) -> (d2, d0, d1)
1982bc4c3e9SNicolas Vasilache     auto comp = compressUnusedDims(map);
1992bc4c3e9SNicolas Vasilache     AffineMap permutationMap = inversePermutation(comp);
2002bc4c3e9SNicolas Vasilache     // Get positions of remaining result dims.
2012bc4c3e9SNicolas Vasilache     SmallVector<int64_t> indices;
2022bc4c3e9SNicolas Vasilache     llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
2032bc4c3e9SNicolas Vasilache                     [](AffineExpr expr) {
2041609f1c2Slong.chen                       return dyn_cast<AffineDimExpr>(expr).getPosition();
2052bc4c3e9SNicolas Vasilache                     });
2062bc4c3e9SNicolas Vasilache 
2072bc4c3e9SNicolas Vasilache     // Transpose in_bounds attribute.
2082bc4c3e9SNicolas Vasilache     ArrayAttr newInBoundsAttr =
2092ee5586aSAndrzej Warzyński         inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
2102bc4c3e9SNicolas Vasilache 
2112bc4c3e9SNicolas Vasilache     // Generate new transfer_write operation.
2122bc4c3e9SNicolas Vasilache     Value newVec = rewriter.create<vector::TransposeOp>(
2132bc4c3e9SNicolas Vasilache         op.getLoc(), op.getVector(), indices);
2142bc4c3e9SNicolas Vasilache     auto newMap = AffineMap::getMinorIdentityMap(
2152bc4c3e9SNicolas Vasilache         map.getNumDims(), map.getNumResults(), rewriter.getContext());
216fdd245adSHugo Trachino     auto newWrite = rewriter.create<vector::TransferWriteOp>(
217fdd245adSHugo Trachino         op.getLoc(), newVec, op.getSource(), op.getIndices(),
218fdd245adSHugo Trachino         AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
219fdd245adSHugo Trachino     if (newWrite.hasPureTensorSemantics())
220fdd245adSHugo Trachino       return newWrite.getResult();
221fdd245adSHugo Trachino     // In the memref case there's no return value. Use empty value to signal
222fdd245adSHugo Trachino     // success.
223fdd245adSHugo Trachino     return Value();
2242bc4c3e9SNicolas Vasilache   }
2252bc4c3e9SNicolas Vasilache };
2262bc4c3e9SNicolas Vasilache 
2272bc4c3e9SNicolas Vasilache /// Convert a transfer.write op with a map which isn't the permutation of a
2282bc4c3e9SNicolas Vasilache /// minor identity into a vector.broadcast + transfer_write with permutation of
2292bc4c3e9SNicolas Vasilache /// minor identity map by adding unit dim on inner dimension. Ex:
2302bc4c3e9SNicolas Vasilache /// ```
2312bc4c3e9SNicolas Vasilache ///   vector.transfer_write %v
2322bc4c3e9SNicolas Vasilache ///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
2332bc4c3e9SNicolas Vasilache ///     vector<8x16xf32>
2342bc4c3e9SNicolas Vasilache /// ```
2352bc4c3e9SNicolas Vasilache /// into:
2362bc4c3e9SNicolas Vasilache /// ```
2372bc4c3e9SNicolas Vasilache ///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
2382bc4c3e9SNicolas Vasilache ///   vector.transfer_write %v1
2392bc4c3e9SNicolas Vasilache ///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
2402bc4c3e9SNicolas Vasilache ///     vector<1x8x16xf32>
2412bc4c3e9SNicolas Vasilache /// ```
2422bc4c3e9SNicolas Vasilache struct TransferWriteNonPermutationLowering
243fdd245adSHugo Trachino     : public MaskableOpRewritePattern<vector::TransferWriteOp> {
244fdd245adSHugo Trachino   using MaskableOpRewritePattern::MaskableOpRewritePattern;
2452bc4c3e9SNicolas Vasilache 
246fdd245adSHugo Trachino   FailureOr<mlir::Value>
247fdd245adSHugo Trachino   matchAndRewriteMaskableOp(vector::TransferWriteOp op,
248fdd245adSHugo Trachino                             MaskingOpInterface maskOp,
2492bc4c3e9SNicolas Vasilache                             PatternRewriter &rewriter) const override {
2508b513407SNicolas Vasilache     // TODO: support 0-d corner case.
2512bc4c3e9SNicolas Vasilache     if (op.getTransferRank() == 0)
2528b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
253fdd245adSHugo Trachino     // TODO: Support transfer_write inside MaskOp case.
254fdd245adSHugo Trachino     if (maskOp)
255fdd245adSHugo Trachino       return rewriter.notifyMatchFailure(op, "Masked case not supported");
2568b513407SNicolas Vasilache 
2572bc4c3e9SNicolas Vasilache     SmallVector<unsigned> permutation;
2582bc4c3e9SNicolas Vasilache     AffineMap map = op.getPermutationMap();
2598b513407SNicolas Vasilache     if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
2608b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(
2618b513407SNicolas Vasilache           op,
2628b513407SNicolas Vasilache           "map is already permutable to minor identity, apply another pattern");
2638b513407SNicolas Vasilache     }
2642bc4c3e9SNicolas Vasilache 
2652bc4c3e9SNicolas Vasilache     // Missing outer dimensions are allowed, find the most outer existing
2662bc4c3e9SNicolas Vasilache     // dimension then deduce the missing inner dimensions.
2672bc4c3e9SNicolas Vasilache     SmallVector<bool> foundDim(map.getNumDims(), false);
2688b513407SNicolas Vasilache     for (AffineExpr exp : map.getResults())
2691609f1c2Slong.chen       foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
2702bc4c3e9SNicolas Vasilache     SmallVector<AffineExpr> exprs;
2712bc4c3e9SNicolas Vasilache     bool foundFirstDim = false;
2722bc4c3e9SNicolas Vasilache     SmallVector<int64_t> missingInnerDim;
2732bc4c3e9SNicolas Vasilache     for (size_t i = 0; i < foundDim.size(); i++) {
2742bc4c3e9SNicolas Vasilache       if (foundDim[i]) {
2752bc4c3e9SNicolas Vasilache         foundFirstDim = true;
2762bc4c3e9SNicolas Vasilache         continue;
2772bc4c3e9SNicolas Vasilache       }
2782bc4c3e9SNicolas Vasilache       if (!foundFirstDim)
2792bc4c3e9SNicolas Vasilache         continue;
2802bc4c3e9SNicolas Vasilache       // Once we found one outer dimension existing in the map keep track of all
2812bc4c3e9SNicolas Vasilache       // the missing dimensions after that.
2822bc4c3e9SNicolas Vasilache       missingInnerDim.push_back(i);
2832bc4c3e9SNicolas Vasilache       exprs.push_back(rewriter.getAffineDimExpr(i));
2842bc4c3e9SNicolas Vasilache     }
285a7a5641bSMatthias Springer     // Vector: add unit dims at the beginning of the shape.
2862bc4c3e9SNicolas Vasilache     Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
2872bc4c3e9SNicolas Vasilache                                     missingInnerDim.size());
288a7a5641bSMatthias Springer     // Mask: add unit dims at the end of the shape.
289a7a5641bSMatthias Springer     Value newMask;
290a7a5641bSMatthias Springer     if (op.getMask())
291a7a5641bSMatthias Springer       newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
292a7a5641bSMatthias Springer                                missingInnerDim.size());
2932bc4c3e9SNicolas Vasilache     exprs.append(map.getResults().begin(), map.getResults().end());
2942bc4c3e9SNicolas Vasilache     AffineMap newMap =
2952bc4c3e9SNicolas Vasilache         AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
2962bc4c3e9SNicolas Vasilache     // All the new dimensions added are inbound.
2972bc4c3e9SNicolas Vasilache     SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
2986040044fSMatthias Springer     for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
2996040044fSMatthias Springer       newInBoundsValues.push_back(op.isDimInBounds(i));
3002bc4c3e9SNicolas Vasilache     }
3016040044fSMatthias Springer     ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
302fdd245adSHugo Trachino     auto newWrite = rewriter.create<vector::TransferWriteOp>(
303fdd245adSHugo Trachino         op.getLoc(), newVec, op.getSource(), op.getIndices(),
304fdd245adSHugo Trachino         AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
305fdd245adSHugo Trachino     if (newWrite.hasPureTensorSemantics())
306fdd245adSHugo Trachino       return newWrite.getResult();
307fdd245adSHugo Trachino     // In the memref case there's no return value. Use empty value to signal
308fdd245adSHugo Trachino     // success.
309fdd245adSHugo Trachino     return Value();
3102bc4c3e9SNicolas Vasilache   }
3112bc4c3e9SNicolas Vasilache };
3122bc4c3e9SNicolas Vasilache 
3132bc4c3e9SNicolas Vasilache /// Lower transfer_read op with broadcast in the leading dimensions into
3142bc4c3e9SNicolas Vasilache /// transfer_read of lower rank + vector.broadcast.
3152bc4c3e9SNicolas Vasilache /// Ex: vector.transfer_read ...
3162bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
3172bc4c3e9SNicolas Vasilache /// into:
3182bc4c3e9SNicolas Vasilache ///     %v = vector.transfer_read ...
3192bc4c3e9SNicolas Vasilache ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
3202bc4c3e9SNicolas Vasilache ///     vector.broadcast %v
3210170498aSHugo Trachino struct TransferOpReduceRank
3220170498aSHugo Trachino     : public MaskableOpRewritePattern<vector::TransferReadOp> {
3230170498aSHugo Trachino   using MaskableOpRewritePattern::MaskableOpRewritePattern;
3242bc4c3e9SNicolas Vasilache 
3250170498aSHugo Trachino   FailureOr<mlir::Value>
3260170498aSHugo Trachino   matchAndRewriteMaskableOp(vector::TransferReadOp op,
3270170498aSHugo Trachino                             MaskingOpInterface maskOp,
3282bc4c3e9SNicolas Vasilache                             PatternRewriter &rewriter) const override {
3292bc4c3e9SNicolas Vasilache     // TODO: support 0-d corner case.
3302bc4c3e9SNicolas Vasilache     if (op.getTransferRank() == 0)
3318b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
3320170498aSHugo Trachino     // TODO: support masked case.
3330170498aSHugo Trachino     if (maskOp)
3340170498aSHugo Trachino       return rewriter.notifyMatchFailure(op, "Masked case not supported");
3352bc4c3e9SNicolas Vasilache 
3362bc4c3e9SNicolas Vasilache     AffineMap map = op.getPermutationMap();
3372bc4c3e9SNicolas Vasilache     unsigned numLeadingBroadcast = 0;
3382bc4c3e9SNicolas Vasilache     for (auto expr : map.getResults()) {
3391609f1c2Slong.chen       auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
3402bc4c3e9SNicolas Vasilache       if (!dimExpr || dimExpr.getValue() != 0)
3412bc4c3e9SNicolas Vasilache         break;
3422bc4c3e9SNicolas Vasilache       numLeadingBroadcast++;
3432bc4c3e9SNicolas Vasilache     }
3442bc4c3e9SNicolas Vasilache     // If there are no leading zeros in the map there is nothing to do.
3452bc4c3e9SNicolas Vasilache     if (numLeadingBroadcast == 0)
3468b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");
3478b513407SNicolas Vasilache 
3482bc4c3e9SNicolas Vasilache     VectorType originalVecType = op.getVectorType();
3492bc4c3e9SNicolas Vasilache     unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
3502bc4c3e9SNicolas Vasilache     // Calculate new map, vector type and masks without the leading zeros.
3512bc4c3e9SNicolas Vasilache     AffineMap newMap = AffineMap::get(
3522bc4c3e9SNicolas Vasilache         map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
3532bc4c3e9SNicolas Vasilache         op.getContext());
3542bc4c3e9SNicolas Vasilache     // Only remove the leading zeros if the rest of the map is a minor identity
3552bc4c3e9SNicolas Vasilache     // with broadasting. Otherwise we first want to permute the map.
3568b513407SNicolas Vasilache     if (!newMap.isMinorIdentityWithBroadcasting()) {
3578b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(
3588b513407SNicolas Vasilache           op, "map is not a minor identity with broadcasting");
3598b513407SNicolas Vasilache     }
3602bc4c3e9SNicolas Vasilache 
36112b49518SAndrzej Warzynski     SmallVector<int64_t> newShape(
3622bc4c3e9SNicolas Vasilache         originalVecType.getShape().take_back(reducedShapeRank));
36312b49518SAndrzej Warzynski     SmallVector<bool> newScalableDims(
36412b49518SAndrzej Warzynski         originalVecType.getScalableDims().take_back(reducedShapeRank));
3658b513407SNicolas Vasilache 
36612b49518SAndrzej Warzynski     VectorType newReadType = VectorType::get(
36712b49518SAndrzej Warzynski         newShape, originalVecType.getElementType(), newScalableDims);
3682bc4c3e9SNicolas Vasilache     ArrayAttr newInBoundsAttr =
3692bc4c3e9SNicolas Vasilache         op.getInBounds()
3702bc4c3e9SNicolas Vasilache             ? rewriter.getArrayAttr(
3712bc4c3e9SNicolas Vasilache                   op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
3722bc4c3e9SNicolas Vasilache             : ArrayAttr();
3732bc4c3e9SNicolas Vasilache     Value newRead = rewriter.create<vector::TransferReadOp>(
3742bc4c3e9SNicolas Vasilache         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
3752bc4c3e9SNicolas Vasilache         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
3762bc4c3e9SNicolas Vasilache         newInBoundsAttr);
3770170498aSHugo Trachino     return rewriter
3780170498aSHugo Trachino         .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
3790170498aSHugo Trachino         .getVector();
3802bc4c3e9SNicolas Vasilache   }
3812bc4c3e9SNicolas Vasilache };
3822bc4c3e9SNicolas Vasilache 
3832bc4c3e9SNicolas Vasilache } // namespace
3842bc4c3e9SNicolas Vasilache 
3852bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
3862bc4c3e9SNicolas Vasilache     RewritePatternSet &patterns, PatternBenefit benefit) {
3872bc4c3e9SNicolas Vasilache   patterns
3882bc4c3e9SNicolas Vasilache       .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
3892bc4c3e9SNicolas Vasilache            TransferOpReduceRank, TransferWriteNonPermutationLowering>(
3902bc4c3e9SNicolas Vasilache           patterns.getContext(), benefit);
3912bc4c3e9SNicolas Vasilache }
3922bc4c3e9SNicolas Vasilache 
3932bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===//
3942bc4c3e9SNicolas Vasilache // populateVectorTransferLoweringPatterns
3952bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===//
3962bc4c3e9SNicolas Vasilache 
3972bc4c3e9SNicolas Vasilache namespace {
3982bc4c3e9SNicolas Vasilache /// Progressive lowering of transfer_read. This pattern supports lowering of
3992bc4c3e9SNicolas Vasilache /// `vector.transfer_read` to a combination of `vector.load` and
4002bc4c3e9SNicolas Vasilache /// `vector.broadcast` if all of the following hold:
4012bc4c3e9SNicolas Vasilache /// - Stride of most minor memref dimension must be 1.
4022bc4c3e9SNicolas Vasilache /// - Out-of-bounds masking is not required.
4032bc4c3e9SNicolas Vasilache /// - If the memref's element type is a vector type then it coincides with the
4042bc4c3e9SNicolas Vasilache ///   result type.
4052bc4c3e9SNicolas Vasilache /// - The permutation map doesn't perform permutation (broadcasting is allowed).
4062bc4c3e9SNicolas Vasilache struct TransferReadToVectorLoadLowering
40774941d05SHugo Trachino     : public MaskableOpRewritePattern<vector::TransferReadOp> {
4082bc4c3e9SNicolas Vasilache   TransferReadToVectorLoadLowering(MLIRContext *context,
4092bc4c3e9SNicolas Vasilache                                    std::optional<unsigned> maxRank,
4102bc4c3e9SNicolas Vasilache                                    PatternBenefit benefit = 1)
41174941d05SHugo Trachino       : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
4122bc4c3e9SNicolas Vasilache         maxTransferRank(maxRank) {}
4132bc4c3e9SNicolas Vasilache 
41474941d05SHugo Trachino   FailureOr<mlir::Value>
41574941d05SHugo Trachino   matchAndRewriteMaskableOp(vector::TransferReadOp read,
41674941d05SHugo Trachino                             MaskingOpInterface maskOp,
4172bc4c3e9SNicolas Vasilache                             PatternRewriter &rewriter) const override {
4188b513407SNicolas Vasilache     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
4198b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(
4208b513407SNicolas Vasilache           read, "vector type is greater than max transfer rank");
4218b513407SNicolas Vasilache     }
4222bc4c3e9SNicolas Vasilache 
42374941d05SHugo Trachino     if (maskOp)
42474941d05SHugo Trachino       return rewriter.notifyMatchFailure(read, "Masked case not supported");
4252bc4c3e9SNicolas Vasilache     SmallVector<unsigned> broadcastedDims;
4262bc4c3e9SNicolas Vasilache     // Permutations are handled by VectorToSCF or
4272bc4c3e9SNicolas Vasilache     // populateVectorTransferPermutationMapLoweringPatterns.
4282bc4c3e9SNicolas Vasilache     // We let the 0-d corner case pass-through as it is supported.
4292bc4c3e9SNicolas Vasilache     if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
4302bc4c3e9SNicolas Vasilache             &broadcastedDims))
4318b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(read, "not minor identity + bcast");
4322bc4c3e9SNicolas Vasilache 
4335550c821STres Popp     auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
4342bc4c3e9SNicolas Vasilache     if (!memRefType)
4358b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(read, "not a memref source");
4362bc4c3e9SNicolas Vasilache 
4372bc4c3e9SNicolas Vasilache     // Non-unit strides are handled by VectorToSCF.
438*6aaa8f25SMatthias Springer     if (!memRefType.isLastDimUnitStride())
4398b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");
4402bc4c3e9SNicolas Vasilache 
4412bc4c3e9SNicolas Vasilache     // If there is broadcasting involved then we first load the unbroadcasted
4422bc4c3e9SNicolas Vasilache     // vector, and then broadcast it with `vector.broadcast`.
4432bc4c3e9SNicolas Vasilache     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
4445262865aSKazu Hirata     SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
4452bc4c3e9SNicolas Vasilache     for (unsigned i : broadcastedDims)
4462bc4c3e9SNicolas Vasilache       unbroadcastedVectorShape[i] = 1;
447ccef726dSBenjamin Maxwell     VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
4482bc4c3e9SNicolas Vasilache         unbroadcastedVectorShape, read.getVectorType().getElementType());
4492bc4c3e9SNicolas Vasilache 
4502bc4c3e9SNicolas Vasilache     // `vector.load` supports vector types as memref's elements only when the
4512bc4c3e9SNicolas Vasilache     // resulting vector type is the same as the element type.
4522bc4c3e9SNicolas Vasilache     auto memrefElTy = memRefType.getElementType();
4535550c821STres Popp     if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
4548b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(read, "incompatible element type");
4552bc4c3e9SNicolas Vasilache 
4562bc4c3e9SNicolas Vasilache     // Otherwise, element types of the memref and the vector must match.
4575550c821STres Popp     if (!isa<VectorType>(memrefElTy) &&
4582bc4c3e9SNicolas Vasilache         memrefElTy != read.getVectorType().getElementType())
4598b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(read, "non-matching element type");
4602bc4c3e9SNicolas Vasilache 
4612bc4c3e9SNicolas Vasilache     // Out-of-bounds dims are handled by MaterializeTransferMask.
4622bc4c3e9SNicolas Vasilache     if (read.hasOutOfBoundsDim())
4638b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
4642bc4c3e9SNicolas Vasilache 
4652bc4c3e9SNicolas Vasilache     // Create vector load op.
46674941d05SHugo Trachino     Operation *res;
4672bc4c3e9SNicolas Vasilache     if (read.getMask()) {
4688c07d5ecSCullen Rhodes       if (read.getVectorType().getRank() != 1)
4698c07d5ecSCullen Rhodes         // vector.maskedload operates on 1-D vectors.
4708c07d5ecSCullen Rhodes         return rewriter.notifyMatchFailure(
4718c07d5ecSCullen Rhodes             read, "vector type is not rank 1, can't create masked load, needs "
4728c07d5ecSCullen Rhodes                   "VectorToSCF");
4738c07d5ecSCullen Rhodes 
4742bc4c3e9SNicolas Vasilache       Value fill = rewriter.create<vector::SplatOp>(
4752bc4c3e9SNicolas Vasilache           read.getLoc(), unbroadcastedVectorType, read.getPadding());
47674941d05SHugo Trachino       res = rewriter.create<vector::MaskedLoadOp>(
4772bc4c3e9SNicolas Vasilache           read.getLoc(), unbroadcastedVectorType, read.getSource(),
4782bc4c3e9SNicolas Vasilache           read.getIndices(), read.getMask(), fill);
4792bc4c3e9SNicolas Vasilache     } else {
48074941d05SHugo Trachino       res = rewriter.create<vector::LoadOp>(
4812bc4c3e9SNicolas Vasilache           read.getLoc(), unbroadcastedVectorType, read.getSource(),
4822bc4c3e9SNicolas Vasilache           read.getIndices());
4832bc4c3e9SNicolas Vasilache     }
4842bc4c3e9SNicolas Vasilache 
4852bc4c3e9SNicolas Vasilache     // Insert a broadcasting op if required.
48674941d05SHugo Trachino     if (!broadcastedDims.empty())
48774941d05SHugo Trachino       res = rewriter.create<vector::BroadcastOp>(
48874941d05SHugo Trachino           read.getLoc(), read.getVectorType(), res->getResult(0));
48974941d05SHugo Trachino     return res->getResult(0);
4902bc4c3e9SNicolas Vasilache   }
4912bc4c3e9SNicolas Vasilache 
4922bc4c3e9SNicolas Vasilache   std::optional<unsigned> maxTransferRank;
4932bc4c3e9SNicolas Vasilache };
4942bc4c3e9SNicolas Vasilache 
4952bc4c3e9SNicolas Vasilache /// Progressive lowering of transfer_write. This pattern supports lowering of
4962bc4c3e9SNicolas Vasilache /// `vector.transfer_write` to `vector.store` if all of the following hold:
4972bc4c3e9SNicolas Vasilache /// - Stride of most minor memref dimension must be 1.
4982bc4c3e9SNicolas Vasilache /// - Out-of-bounds masking is not required.
4992bc4c3e9SNicolas Vasilache /// - If the memref's element type is a vector type then it coincides with the
5002bc4c3e9SNicolas Vasilache ///   type of the written value.
5012bc4c3e9SNicolas Vasilache /// - The permutation map is the minor identity map (neither permutation nor
5022bc4c3e9SNicolas Vasilache ///   broadcasting is allowed).
5032bc4c3e9SNicolas Vasilache struct TransferWriteToVectorStoreLowering
50474941d05SHugo Trachino     : public MaskableOpRewritePattern<vector::TransferWriteOp> {
5052bc4c3e9SNicolas Vasilache   TransferWriteToVectorStoreLowering(MLIRContext *context,
5062bc4c3e9SNicolas Vasilache                                      std::optional<unsigned> maxRank,
5072bc4c3e9SNicolas Vasilache                                      PatternBenefit benefit = 1)
50874941d05SHugo Trachino       : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
5092bc4c3e9SNicolas Vasilache         maxTransferRank(maxRank) {}
5102bc4c3e9SNicolas Vasilache 
51174941d05SHugo Trachino   FailureOr<mlir::Value>
51274941d05SHugo Trachino   matchAndRewriteMaskableOp(vector::TransferWriteOp write,
51374941d05SHugo Trachino                             MaskingOpInterface maskOp,
5142bc4c3e9SNicolas Vasilache                             PatternRewriter &rewriter) const override {
5158b513407SNicolas Vasilache     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
5168b513407SNicolas Vasilache       return rewriter.notifyMatchFailure(
5178b513407SNicolas Vasilache           write, "vector type is greater than max transfer rank");
5188b513407SNicolas Vasilache     }
51974941d05SHugo Trachino     if (maskOp)
52074941d05SHugo Trachino       return rewriter.notifyMatchFailure(write, "Masked case not supported");
5212bc4c3e9SNicolas Vasilache 
5222bc4c3e9SNicolas Vasilache     // Permutations are handled by VectorToSCF or
5232bc4c3e9SNicolas Vasilache     // populateVectorTransferPermutationMapLoweringPatterns.
5242bc4c3e9SNicolas Vasilache     if ( // pass-through for the 0-d corner case.
5252bc4c3e9SNicolas Vasilache         !write.getPermutationMap().isMinorIdentity())
5262bc4c3e9SNicolas Vasilache       return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
5272bc4c3e9SNicolas Vasilache         diag << "permutation map is not minor identity: " << write;
5282bc4c3e9SNicolas Vasilache       });
5292bc4c3e9SNicolas Vasilache 
5305550c821STres Popp     auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
5312bc4c3e9SNicolas Vasilache     if (!memRefType)
5322bc4c3e9SNicolas Vasilache       return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
5332bc4c3e9SNicolas Vasilache         diag << "not a memref type: " << write;
5342bc4c3e9SNicolas Vasilache       });
5352bc4c3e9SNicolas Vasilache 
5362bc4c3e9SNicolas Vasilache     // Non-unit strides are handled by VectorToSCF.
537*6aaa8f25SMatthias Springer     if (!memRefType.isLastDimUnitStride())
5382bc4c3e9SNicolas Vasilache       return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
5392bc4c3e9SNicolas Vasilache         diag << "most minor stride is not 1: " << write;
5402bc4c3e9SNicolas Vasilache       });
5412bc4c3e9SNicolas Vasilache 
5422bc4c3e9SNicolas Vasilache     // `vector.store` supports vector types as memref's elements only when the
5432bc4c3e9SNicolas Vasilache     // type of the vector value being written is the same as the element type.
5442bc4c3e9SNicolas Vasilache     auto memrefElTy = memRefType.getElementType();
5455550c821STres Popp     if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
5462bc4c3e9SNicolas Vasilache       return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
5472bc4c3e9SNicolas Vasilache         diag << "elemental type mismatch: " << write;
5482bc4c3e9SNicolas Vasilache       });
5492bc4c3e9SNicolas Vasilache 
5502bc4c3e9SNicolas Vasilache     // Otherwise, element types of the memref and the vector must match.
5515550c821STres Popp     if (!isa<VectorType>(memrefElTy) &&
5522bc4c3e9SNicolas Vasilache         memrefElTy != write.getVectorType().getElementType())
5532bc4c3e9SNicolas Vasilache       return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
5542bc4c3e9SNicolas Vasilache         diag << "elemental type mismatch: " << write;
5552bc4c3e9SNicolas Vasilache       });
5562bc4c3e9SNicolas Vasilache 
5572bc4c3e9SNicolas Vasilache     // Out-of-bounds dims are handled by MaterializeTransferMask.
5582bc4c3e9SNicolas Vasilache     if (write.hasOutOfBoundsDim())
5592bc4c3e9SNicolas Vasilache       return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
5602bc4c3e9SNicolas Vasilache         diag << "out of bounds dim: " << write;
5612bc4c3e9SNicolas Vasilache       });
5622bc4c3e9SNicolas Vasilache     if (write.getMask()) {
5638c07d5ecSCullen Rhodes       if (write.getVectorType().getRank() != 1)
5648c07d5ecSCullen Rhodes         // vector.maskedstore operates on 1-D vectors.
5658c07d5ecSCullen Rhodes         return rewriter.notifyMatchFailure(
5668c07d5ecSCullen Rhodes             write.getLoc(), [=](Diagnostic &diag) {
5678c07d5ecSCullen Rhodes               diag << "vector type is not rank 1, can't create masked store, "
5688c07d5ecSCullen Rhodes                       "needs VectorToSCF: "
5698c07d5ecSCullen Rhodes                    << write;
5708c07d5ecSCullen Rhodes             });
5718c07d5ecSCullen Rhodes 
57274941d05SHugo Trachino       rewriter.create<vector::MaskedStoreOp>(
57374941d05SHugo Trachino           write.getLoc(), write.getSource(), write.getIndices(),
57474941d05SHugo Trachino           write.getMask(), write.getVector());
5752bc4c3e9SNicolas Vasilache     } else {
57674941d05SHugo Trachino       rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
57774941d05SHugo Trachino                                        write.getSource(), write.getIndices());
5782bc4c3e9SNicolas Vasilache     }
57974941d05SHugo Trachino     // There's no return value for StoreOps. Use Value() to signal success to
58074941d05SHugo Trachino     // matchAndRewrite.
58174941d05SHugo Trachino     return Value();
5822bc4c3e9SNicolas Vasilache   }
5832bc4c3e9SNicolas Vasilache 
5842bc4c3e9SNicolas Vasilache   std::optional<unsigned> maxTransferRank;
5852bc4c3e9SNicolas Vasilache };
5862bc4c3e9SNicolas Vasilache } // namespace
5872bc4c3e9SNicolas Vasilache 
5882bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorTransferLoweringPatterns(
5892bc4c3e9SNicolas Vasilache     RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
5902bc4c3e9SNicolas Vasilache     PatternBenefit benefit) {
5912bc4c3e9SNicolas Vasilache   patterns.add<TransferReadToVectorLoadLowering,
5922bc4c3e9SNicolas Vasilache                TransferWriteToVectorStoreLowering>(patterns.getContext(),
5932bc4c3e9SNicolas Vasilache                                                    maxTransferRank, benefit);
5942bc4c3e9SNicolas Vasilache }
595