//===- LowerVectorTransfer.cpp - Vector transfer lowering patterns -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations, as well as patterns that progressively lower
// vector.transfer_read/write to vector.load/store and memref.load/store.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// inverse of the given permutation to its values.
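/// E.g., with permutation [2, 0, 1] and in_bounds [true, false, true]
/// (illustrative values), the result is [false, true, true], since
/// result[permutation[i]] = in_bounds[i].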
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
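/// E.g., with addedRank = 2 (illustrative), a vector<2x3xf32> value is
/// broadcast to vector<1x1x2x3xf32>.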
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());

  SmallVector<bool> newScalableDims(addedRank, false);
  newScalableDims.append(originalVecType.getScalableDims().begin(),
                         originalVecType.getScalableDims().end());
  VectorType newVecType = VectorType::get(
      newShape, originalVecType.getElementType(), newScalableDims);
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
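/// E.g., with addedRank = 1 (illustrative), a vector<2x3xi1> mask is first
/// broadcast to vector<1x2x3xi1> and then transposed to vector<2x3x1xi1>.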
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_read inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op, "map is already identity");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    return rewriter
        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
        .getResult();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map. E.g.:
    //  map  = (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //  comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return dyn_cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Convert a transfer_write op with a map which is not a permutation of a
/// minor identity into a vector.broadcast + transfer_write with permutation of
/// minor identity map by adding unit dimensions. Ex:
/// ```
///   vector.transfer_write %v
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///     vector<8x16xf32>
/// ```
/// into:
/// ```
///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///   vector.transfer_write %v1
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///     vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost existing
    // dimension, then deduce the missing inner dimensions.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once we have found one outer dimension that exists in the map, keep
      // track of all the missing dimensions that follow it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: support masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!constExpr || constExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively.  See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      return rewriter
          .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
          .getVector();
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
    if (newShape.empty())
      return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    return rewriter
        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
        .getVector();
  }
};

} // namespace

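// A minimal usage sketch (illustrative; the anchor op `funcOp` and the greedy
// driver invocation are assumptions, not part of this file):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();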
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
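/// A sketch of the expected rewrite (illustrative types and names):
///     %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
///         : memref<8xf32>, vector<4xf32>
/// into:
///     %v = vector.load %mem[%i] : memref<8xf32>, vector<4xf32>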
struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    if (maskOp)
      return rewriter.notifyMatchFailure(read, "Masked case not supported");
    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *res;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      res = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      res = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty())
      res = rewriter.create<vector::BroadcastOp>(
          read.getLoc(), read.getVectorType(), res->getResult(0));
    return res->getResult(0);
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a single-element (e.g. 0-d) vector.load with a memref.load +
/// vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but at the moment we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
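// A sketch of the rewrite (illustrative types and names):
//     %v = vector.load %mem[] : memref<f32>, vector<f32>
// into:
//     %s = memref.load %mem[] : memref<f32>
//     %v = vector.broadcast %s : f32 to vector<f32>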
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");

    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a single-element (e.g. 0-d) vector.store with a vector.extract /
/// vector.extractelement + memref.store.
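/// A sketch of the rewrite (illustrative types and names):
///     vector.store %v, %mem[%i] : memref<8xf32>, vector<1xf32>
/// into:
///     %s = vector.extract %v[0] : f32 from vector<1xf32>
///     memref.store %s, %mem[%i] : memref<8xf32>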
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");

    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
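/// A sketch of the expected rewrite (illustrative types and names):
///     vector.transfer_write %v, %mem[%i] {in_bounds = [true]}
///         : vector<4xf32>, memref<8xf32>
/// into:
///     vector.store %v, %mem[%i] : memref<8xf32>, vector<4xf32>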
struct TransferWriteToVectorStoreLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }
    if (maskOp)
      return rewriter.notifyMatchFailure(write, "Masked case not supported");

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass through, as it is supported.
    if (!write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "element type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "element type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.create<vector::MaskedStoreOp>(
          write.getLoc(), write.getSource(), write.getIndices(),
          write.getMask(), write.getVector());
    } else {
      rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
                                       write.getSource(), write.getIndices());
    }
    // There's no return value for StoreOps. Use an empty Value to signal
    // success to matchAndRewrite.
    return Value();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}