xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (revision ccef726d09b1ffadfae6b1d1d986ae2f6d25a6a6)
//===- LowerVectorTransfer.cpp - Lower 'vector.transfer' operations ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations, as well as progressive lowerings of
// vector.transfer ops to vector.load/vector.store and memref.load/memref.store
// ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// inverse of the given permutation.
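/// E.g., for permutation [2, 0, 1] and in_bounds [true, false, true], the
/// result is [false, true, true] (hypothetical values, for illustration).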
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
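/// E.g., with addedRank = 2, a vector<8x16xf32> is broadcast to
/// vector<1x1x8x16xf32> (hypothetical shapes, for illustration).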
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());
  VectorType newVecType =
      VectorType::get(newShape, originalVecType.getElementType());
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
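/// E.g., with addedRank = 1, a mask vector<8x16xi1> is first broadcast to
/// vector<1x8x16xi1> and then transposed with [1, 2, 0] to vector<8x16x1xi1>
/// (hypothetical shapes, for illustration).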
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = broadcasted.getType().cast<VectorType>().getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower a transfer_read op with a permutation map into a transfer_read with
/// a permutation map composed of leading zeros followed by a minor identity,
/// plus a vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op, "map is already identity");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
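    // E.g., for map (d0, d1, d2) -> (0, d1), permutation is [1, 0], so
    // permutationMap is (d0, d1) -> (d1, d0) (its own inverse) and newMap is
    // (d0, d1, d2) -> (d1, 0); the transpose below restores the original
    // order (values follow the first example in the pattern documentation).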
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }
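    // E.g., with permutation [1, 0], a vector<2x4xf32> read is rewritten as a
    // vector<4x2xf32> read that is transposed back below (hypothetical shapes,
    // for illustration).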

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                     transposePerm);
    return success();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [1, 2, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map, e.g.:
    //   map  = (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return expr.dyn_cast<AffineDimExpr>().getPosition();
                    });
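    // E.g., for comp = (d0, d1, d2) -> (d2, d0, d1) the inverse is
    // (d0, d1, d2) -> (d1, d2, d0), so indices is [1, 2, 0] (values follow
    // the comment above).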

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        op.getMask(), newInBoundsAttr);

    return success();
  }
};

/// Convert a transfer_write op with a map that is not a permutation of a minor
/// identity into a vector.broadcast + transfer_write whose map is a
/// permutation of a minor identity, by adding unit dimensions for the missing
/// inner dimensions. Ex:
/// ```
///   vector.transfer_write %v
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///     vector<8x16xf32>
/// ```
/// into:
/// ```
///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///   vector.transfer_write %v1
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///     vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost dimension that
    // appears in the map, then deduce the missing inner dimensions.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[exp.cast<AffineDimExpr>().getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once we have found one outer dimension that exists in the map, keep
      // track of all the missing dimensions after it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
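    // E.g., for map (d0, d1, d2, d3) -> (d1, d2), only d3 is missing after the
    // first found dimension, so missingInnerDim is [3] and the final map below
    // becomes (d0, d1, d2, d3) -> (d3, d1, d2) (values follow the example in
    // the pattern documentation above).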
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        newMask, newInBoundsAttr);
    return success();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto constExpr = expr.dyn_cast<AffineConstantExpr>();
      if (!constExpr || constExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map, there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively. See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                       newRead);
      return success();
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
    if (newShape.empty())
      return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                     newRead);
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}
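
// A minimal usage sketch (hypothetical pass body; assumes a `funcOp` handle to
// the enclosing function, not code from this file):
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));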

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
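/// Ex (an illustrative sketch, assuming an unmasked, in-bounds read with unit
/// stride):
/// ```
///   %0 = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]}
///       : memref<8x16xf32>, vector<16xf32>
/// ```
/// is lowered to:
/// ```
///   %0 = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<16xf32>
/// ```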
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass through, as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *loadOp;
    if (read.getMask()) {
      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d (or, more generally, single-element) vector.load with a
/// memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
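/// Ex (an illustrative sketch):
/// ```
///   %0 = vector.load %mem[%i] : memref<8xf32>, vector<1xf32>
/// ```
/// is lowered to:
/// ```
///   %s = memref.load %mem[%i] : memref<8xf32>
///   %0 = vector.broadcast %s : f32 to vector<1xf32>
/// ```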
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");

    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d (or, more generally, single-element) vector.store with a
/// vector.extractelement (or vector.extract) + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");

    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
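/// Ex (an illustrative sketch, assuming an unmasked, in-bounds write with unit
/// stride):
/// ```
///   vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true]}
///       : vector<16xf32>, memref<8x16xf32>
/// ```
/// is lowered to:
/// ```
///   vector.store %v, %mem[%i, %j] : memref<8x16xf32>, vector<16xf32>
/// ```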
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // The 0-d corner case deliberately passes through.
    if (!write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}
625