xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (revision dedc7d4d362b8045c6810f8ca7f947bbdb63b7ec)
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <numeric>
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/TypeUtilities.h"
19 
20 #define DEBUG_TYPE "vector-drop-unit-dim"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 
25 // Trims leading one dimensions from `oldType` and returns the result type.
26 // Returns `vector<1xT>` if `oldType` only has one element.
27 static VectorType trimLeadingOneDims(VectorType oldType) {
28   ArrayRef<int64_t> oldShape = oldType.getShape();
29   ArrayRef<int64_t> newShape = oldShape;
30 
31   ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
32   ArrayRef<bool> newScalableDims = oldScalableDims;
33 
34   while (!newShape.empty() && newShape.front() == 1 &&
35          !newScalableDims.front()) {
36     newShape = newShape.drop_front(1);
37     newScalableDims = newScalableDims.drop_front(1);
38   }
39 
40   // Make sure we have at least 1 dimension per vector type requirements.
41   if (newShape.empty()) {
42     newShape = oldShape.take_back();
43     newScalableDims = oldType.getScalableDims().take_back();
44   }
45   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
46 }
47 
48 /// Return a smallVector of size `rank` containing all zeros.
49 static SmallVector<int64_t> splatZero(int64_t rank) {
50   return SmallVector<int64_t>(rank, 0);
51 }
52 namespace {
53 
54 // Casts away leading one dimensions in vector.extract_strided_slice's vector
55 // input by inserting vector.broadcast.
56 struct CastAwayExtractStridedSliceLeadingOneDim
57     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
58   using OpRewritePattern::OpRewritePattern;
59 
60   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
61                                 PatternRewriter &rewriter) const override {
62     // vector.extract_strided_slice requires the input and output vector to have
63     // the same rank. Here we drop leading one dimensions from the input vector
64     // type to make sure we don't cause mismatch.
65     VectorType oldSrcType = extractOp.getSourceVectorType();
66     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
67 
68     if (newSrcType.getRank() == oldSrcType.getRank())
69       return failure();
70 
71     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
72 
73     VectorType oldDstType = extractOp.getType();
74     VectorType newDstType =
75         VectorType::get(oldDstType.getShape().drop_front(dropCount),
76                         oldDstType.getElementType());
77 
78     Location loc = extractOp.getLoc();
79 
80     Value newSrcVector = rewriter.create<vector::ExtractOp>(
81         loc, extractOp.getVector(), splatZero(dropCount));
82 
83     // The offsets/sizes/strides attribute can have a less number of elements
84     // than the input vector's rank: it is meant for the leading dimensions.
85     auto newOffsets = rewriter.getArrayAttr(
86         extractOp.getOffsets().getValue().drop_front(dropCount));
87     auto newSizes = rewriter.getArrayAttr(
88         extractOp.getSizes().getValue().drop_front(dropCount));
89     auto newStrides = rewriter.getArrayAttr(
90         extractOp.getStrides().getValue().drop_front(dropCount));
91 
92     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
93         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
94 
95     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
96                                                      newExtractOp);
97 
98     return success();
99   }
100 };
101 
102 // Casts away leading one dimensions in vector.insert_strided_slice's vector
103 // inputs by inserting vector.broadcast.
104 struct CastAwayInsertStridedSliceLeadingOneDim
105     : public OpRewritePattern<vector::InsertStridedSliceOp> {
106   using OpRewritePattern::OpRewritePattern;
107 
108   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
109                                 PatternRewriter &rewriter) const override {
110     VectorType oldSrcType = insertOp.getSourceVectorType();
111     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
112     VectorType oldDstType = insertOp.getDestVectorType();
113     VectorType newDstType = trimLeadingOneDims(oldDstType);
114 
115     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
116     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
117     if (srcDropCount == 0 && dstDropCount == 0)
118       return failure();
119 
120     // Trim leading one dimensions from both operands.
121     Location loc = insertOp.getLoc();
122 
123     Value newSrcVector = rewriter.create<vector::ExtractOp>(
124         loc, insertOp.getSource(), splatZero(srcDropCount));
125     Value newDstVector = rewriter.create<vector::ExtractOp>(
126         loc, insertOp.getDest(), splatZero(dstDropCount));
127 
128     auto newOffsets = rewriter.getArrayAttr(
129         insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
130     auto newStrides = rewriter.getArrayAttr(
131         insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
132 
133     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
134         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
135 
136     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
137                                                      newInsertOp);
138 
139     return success();
140   }
141 };
142 
// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    // The inserted source may be a scalar; in that case it has rank 0 and
    // there is nothing to trim on the source side.
    Type oldSrcType = insertOp.getSourceType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = dyn_cast<VectorType>(oldSrcType)) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = cast<VectorType>(newSrcType).getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    // Nothing to do if neither operand has droppable leading unit dims.
    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    // Scalar sources (rank 0) cannot be extracted from; pass them through.
    Value newSrcVector = insertOp.getSource();
    if (oldSrcRank != 0) {
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getSource(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // New position rank needs to be computed in two steps: (1) if destination
    // type has leading unit dims, we also trim the position array accordingly,
    // then (2) if source type also has leading unit dims, we need to append
    // zeroes to the position array accordingly.
    unsigned oldPosRank = insertOp.getNumIndices();
    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
    SmallVector<OpFoldResult> newPosition =
        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
    // Pad with zero indices so the position rank equals the rank difference
    // between the trimmed destination and the trimmed source.
    newPosition.resize(newDstType.getRank() - newSrcRank,
                       rewriter.getI64IntegerAttr(0));

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newSrcVector, newDstVector, newPosition);

    // Broadcast the trimmed insert back to the original destination type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};
199 
200 static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
201                                   VectorType newType, AffineMap newMap,
202                                   VectorType oldMaskType) {
203   // Infer the type of the new mask from the new map.
204   VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
205 
206   // If the new mask is broadcastable to the old result type, we can safely
207   // use a `vector.extract` to get the new mask. Otherwise the best we can
208   // do is shape cast.
209   if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
210       BroadcastableToResult::Success) {
211     int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
212     return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
213   }
214   return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
215 }
216 
// Turns vector.transfer_read on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_read on vector without leading
// 1 dimensions.
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Not supported masked op yet.
    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
      return failure();
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    // Only handle reads whose source element type matches the result vector's
    // element type.
    auto shapedType = cast<ShapedType>(read.getSource().getType());
    if (shapedType.getElementType() != read.getVectorType().getElementType())
      return failure();

    VectorType oldType = read.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);

    // No leading unit dims to drop.
    if (newType == oldType)
      return failure();

    // Keep only the permutation-map results for the surviving dimensions
    // (one map result per result-vector dim).
    AffineMap oldMap = read.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    // in_bounds has one entry per result dim; trim it to match the new rank.
    ArrayAttr inBoundsAttr;
    if (read.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          read.getInBoundsAttr().getValue().take_back(newType.getRank()));

    // Trim the mask operand (if present) to the new transfer shape.
    Value mask = Value();
    if (read.getMask()) {
      VectorType maskType = read.getMaskType();
      mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
                                  newType, newMap, maskType);
    }

    // Issue the trimmed read, then broadcast back to the original type.
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), newType, read.getSource(), read.getIndices(),
        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);

    return success();
  }
};
270 
271 // Turns vector.transfer_write on vector with leading 1 dimensions into
272 // vector.shape_cast followed by vector.transfer_write on vector without leading
273 // 1 dimensions.
274 struct CastAwayTransferWriteLeadingOneDim
275     : public OpRewritePattern<vector::TransferWriteOp> {
276   using OpRewritePattern::OpRewritePattern;
277 
278   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
279                                 PatternRewriter &rewriter) const override {
280     // TODO(#78787): Not supported masked op yet.
281     if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
282       return failure();
283     // TODO: support 0-d corner case.
284     if (write.getTransferRank() == 0)
285       return failure();
286 
287     auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
288     if (shapedType.getElementType() != write.getVectorType().getElementType())
289       return failure();
290 
291     VectorType oldType = write.getVectorType();
292     VectorType newType = trimLeadingOneDims(oldType);
293     if (newType == oldType)
294       return failure();
295     int64_t dropDim = oldType.getRank() - newType.getRank();
296 
297     AffineMap oldMap = write.getPermutationMap();
298     ArrayRef<AffineExpr> newResults =
299         oldMap.getResults().take_back(newType.getRank());
300     AffineMap newMap =
301         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
302                        rewriter.getContext());
303 
304     ArrayAttr inBoundsAttr;
305     if (write.getInBounds())
306       inBoundsAttr = rewriter.getArrayAttr(
307           write.getInBoundsAttr().getValue().take_back(newType.getRank()));
308 
309     auto newVector = rewriter.create<vector::ExtractOp>(
310         write.getLoc(), write.getVector(), splatZero(dropDim));
311 
312     if (write.getMask()) {
313       VectorType maskType = write.getMaskType();
314       Value newMask = dropUnitDimsFromMask(
315           rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
316       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
317           write, newVector, write.getSource(), write.getIndices(),
318           AffineMapAttr::get(newMap), newMask, inBoundsAttr);
319       return success();
320     }
321 
322     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
323         write, newVector, write.getSource(), write.getIndices(),
324         AffineMapAttr::get(newMap), inBoundsAttr);
325     return success();
326   }
327 };
328 
329 } // namespace
330 
LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                               RewriterBase &rewriter) {
  // TODO(#78787): Not supported masked op yet.
  if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
    return failure();
  // Only vector accumulators with a static leading unit dim are handled.
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  if (oldAccType == nullptr)
    return failure();
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // currently we support only dropping one dim but the pattern can be applied
  // greedily to drop more.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  // The iteration dimension to drop is the one feeding the accumulator's
  // leading unit dim (indexing map 2 is the accumulator's map).
  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    // only parallel type iterators can be dropped.
    return failure();

  // Keep every iterator type except the dropped dimension's.
  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // Check if the dim to be dropped exists as a leading dim in the operand
    // if it does then we use vector.extract to drop it.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t orginalZeroDim = it.value().getDimPosition(0);
    if (orginalZeroDim != dimToDrop) {
      // There are two reasons to be in this path, 1. We need to
      // transpose the operand to make the dim to be dropped
      // leading. 2. The dim to be dropped does not exist and in
      // that case we don't want to add a unit transpose but we must
      // check all the indices to make sure this is the case.
      bool tranposeNeeded = false;
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;

      // Build a permutation that moves the dropped dim to the front while
      // keeping the relative order of the remaining dims.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          tranposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }
      // Do the transpose now if needed so that we can drop the
      // correct dim using extract later.
      if (tranposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        operands[it.index()] = rewriter.create<vector::TransposeOp>(
            contractOp.getLoc(), operands[it.index()], perm);
      }
    }
    // We have taken care to have the dim to be dropped be
    // the leading dim. If its still not leading that means it
    // does not exist in this operand and hence we do not need
    // an extract.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    // Rebuild the indexing map on the reduced iteration space: iteration
    // dims above the dropped one shift down by one.
    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        // This is the dim we are dropping.
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    // Extract if its a valid extraction, otherwise use the operand
    // without extraction.
    newOperands.push_back(
        validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
                                                          operands[it.index()],
                                                          splatZero(dropDim))
                     : operands[it.index()]);
  }
  // Rebuild the contraction on the trimmed operands and broadcast the result
  // back to the original (unit-leading-dim) type.
  auto newContractOp = rewriter.create<vector::ContractionOp>(
      contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
      contractOp, contractOp->getResultTypes()[0], newContractOp);
  return success();
}
444 
445 namespace {
446 
447 /// Turns vector.contract on vector with leading 1 dimensions into
448 /// vector.extract followed by vector.contract on vector without leading
449 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
450 /// prior to extract.
451 struct CastAwayContractionLeadingOneDim
452     : public OpRewritePattern<vector::ContractionOp> {
453   using OpRewritePattern::OpRewritePattern;
454 
455   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
456                                 PatternRewriter &rewriter) const override {
457     return castAwayContractionLeadingOneDim(contractOp, rewriter);
458   }
459 };
460 
461 /// Looks at elementwise operations on vectors with at least one leading
462 /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
463 /// and cast aways the leading one dimensions (_plural_) and then broadcasts
464 /// the results.
465 ///
466 /// Example before:
467 ///     %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
468 /// Example after:
469 ///    %2 = arith.mulf %0, %1 : vector<4x1xf32>
470 ///    %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
471 ///
472 /// Does support scalable vectors.
473 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
474 public:
475   CastAwayElementwiseLeadingOneDim(MLIRContext *context,
476                                    PatternBenefit benefit = 1)
477       : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
478 
479   LogicalResult matchAndRewrite(Operation *op,
480                                 PatternRewriter &rewriter) const override {
481     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
482       return failure();
483     auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
484     if (!vecType)
485       return failure();
486     VectorType newVecType = trimLeadingOneDims(vecType);
487     if (newVecType == vecType)
488       return failure();
489     int64_t dropDim = vecType.getRank() - newVecType.getRank();
490     SmallVector<Value, 4> newOperands;
491     for (Value operand : op->getOperands()) {
492       if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
493         newOperands.push_back(rewriter.create<vector::ExtractOp>(
494             op->getLoc(), operand, splatZero(dropDim)));
495       } else {
496         newOperands.push_back(operand);
497       }
498     }
499     Operation *newOp =
500         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
501                         newOperands, newVecType, op->getAttrs());
502     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
503                                                      newOp->getResult(0));
504     return success();
505   }
506 };
507 
// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
struct CastAwayConstantMaskLeadingOneDim
    : public OpRewritePattern<vector::ConstantMaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
                                PatternRewriter &rewriter) const override {
    VectorType oldType = mask.getType();
    VectorType newType = trimLeadingOneDims(oldType);

    // No leading unit dims to drop.
    if (newType == oldType)
      return failure();

    int64_t dropDim = oldType.getRank() - newType.getRank();
    // Collect the per-dimension mask sizes as plain integers.
    SmallVector<int64_t> dimSizes;
    for (auto attr : mask.getMaskDimSizes())
      dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());

    // If any of the dropped unit dims has a size of `0`, the entire mask is a
    // zero mask, else the unit dim has no effect on the mask. Folding the
    // first dropDim+1 sizes into the new leading size captures both cases:
    // the product is 0 if any dropped dim was 0, else the original leading
    // size of the surviving dim.
    int64_t flatLeadingSize =
        std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    SmallVector<int64_t> newDimSizes({flatLeadingSize});
    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

    // Build the trimmed constant mask, then broadcast back to the old shape.
    auto newMask = rewriter.create<vector::ConstantMaskOp>(
        mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
    return success();
  }
};
541 
542 } // namespace
543 
544 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
545     RewritePatternSet &patterns, PatternBenefit benefit) {
546   patterns
547       .add<CastAwayExtractStridedSliceLeadingOneDim,
548            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
549            CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
550            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
551            CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
552   populateShapeCastFoldingPatterns(patterns, benefit);
553 }
554