xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <cassert>
16 #include <cstdint>
17 #include <functional>
18 #include <optional>
19 #include <type_traits>
20 
21 #include "mlir/Dialect/Affine/IR/AffineOps.h"
22 #include "mlir/Dialect/Arith/IR/Arith.h"
23 #include "mlir/Dialect/Arith/Utils/Utils.h"
24 #include "mlir/Dialect/Linalg/IR/Linalg.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Dialect/Utils/IndexingUtils.h"
29 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
30 #include "mlir/Dialect/Vector/IR/VectorOps.h"
31 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
32 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
33 #include "mlir/IR/BuiltinAttributeInterfaces.h"
34 #include "mlir/IR/BuiltinTypes.h"
35 #include "mlir/IR/ImplicitLocOpBuilder.h"
36 #include "mlir/IR/Location.h"
37 #include "mlir/IR/Matchers.h"
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Interfaces/VectorInterfaces.h"
41 
42 #include "llvm/ADT/DenseSet.h"
43 #include "llvm/ADT/MapVector.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/Debug.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include "llvm/Support/raw_ostream.h"
49 
50 #define DEBUG_TYPE "vector-to-vector"
51 
52 using namespace mlir;
53 using namespace mlir::vector;
54 
55 template <typename IntType>
56 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
57   return llvm::to_vector<4>(llvm::map_range(
58       arrayAttr.getAsRange<IntegerAttr>(),
59       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
60 }
61 
62 // Helper that returns the position of the result of `map` that is dimension
62 // `index`, or std::nullopt if no such result exists.
63 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
64   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
65     int64_t idx = map.getDimPosition(i);
66     if (idx == index)
67       return i;
68   }
69   return std::nullopt;
70 }
71 
72 namespace {
73 
74 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
75 //
76 // Example:
77 //
78 //  The following MLIR with cancelling ShapeCastOps:
79 //
80 //   %0 = source : vector<5x4x2xf32>
81 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
82 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
83 //   %3 = user %2 : vector<5x4x2xf32>
84 //
85 //  Should canonicalize to the following:
86 //
87 //   %0 = source : vector<5x4x2xf32>
88 //   %1 = user %0 : vector<5x4x2xf32>
89 //
90 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
91   using OpRewritePattern::OpRewritePattern;
92 
93   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
94                                 PatternRewriter &rewriter) const override {
95     // Check if 'shapeCastOp' has vector source/result type.
96     auto sourceVectorType =
97         dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
98     auto resultVectorType =
99         dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
100     if (!sourceVectorType || !resultVectorType)
101       return failure();
102 
103     // Check if shape cast op source operand is also a shape cast op.
104     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
105         shapeCastOp.getSource().getDefiningOp());
106     if (!sourceShapeCastOp)
107       return failure();
108     auto operandSourceVectorType =
109         cast<VectorType>(sourceShapeCastOp.getSource().getType());
110     auto operandResultVectorType = sourceShapeCastOp.getType();
111 
112     // Check if shape cast operations invert each other.
113     if (operandSourceVectorType != resultVectorType ||
114         operandResultVectorType != sourceVectorType)
115       return failure();
116 
117     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
118     return success();
119   }
120 };
121 
122 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
123 /// Ex:
124 /// ```
125 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
126 ///   %1 = vector.multi_reduction <add>, %0, %cst_f0 [2]
127 ///     : vector<8x32x16xf32> to vector<8x32xf32>
128 /// ```
129 /// Gets converted to:
130 /// ```
131 ///   %1 = vector.contract {indexing_maps = [
132 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
133 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
134 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
135 ///    iterator_types = ["parallel", "parallel", "reduction"],
136 ///    kind = add} %arg0, %arg1, %cst_f0
137 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
138 ///  ```
139 struct MultiReduceToContract
140     : public OpRewritePattern<vector::MultiDimReductionOp> {
141   using OpRewritePattern::OpRewritePattern;
142 
143   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
144                                 PatternRewriter &rewriter) const override {
145     if (reduceOp.getKind() != vector::CombiningKind::ADD)
146       return failure();
147     Operation *mulOp = reduceOp.getSource().getDefiningOp();
148     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
149       return failure();
150     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
151     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
152     SmallVector<AffineExpr> exprs;
153     SmallVector<vector::IteratorType> iteratorTypes;
154     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
155       if (!isReduceDim.value()) {
156         iteratorTypes.push_back(vector::IteratorType::parallel);
157         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
158       } else {
159         iteratorTypes.push_back(vector::IteratorType::reduction);
160       }
161     }
162     auto dstMap =
163         AffineMap::get(/*dimCount=*/reductionMask.size(),
164                        /*symbolCount=*/0, exprs, reduceOp.getContext());
165     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
166         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
167         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
168         rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
169             iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
170               return IteratorTypeAttr::get(rewriter.getContext(), t);
171             }))));
172     return success();
173   }
174 };
175 
176 /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
177 /// Ex:
178 /// ```
179 ///   %0 = vector.transpose %arg0, [2, 0, 1]
180 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
181 ///   %1 = vector.contract {indexing_maps = [
182 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
183 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
184 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
185 ///    iterator_types = ["parallel", "parallel", "reduction"],
186 ///    kind = add} %0, %arg1, %cst_f0
187 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
188 /// ```
189 /// Gets converted to:
190 /// ```
191 ///   %1 = vector.contract {indexing_maps = [
192 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
193 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
194 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
195 ///    iterator_types = ["parallel", "parallel", "reduction"],
196 ///    kind = add} %arg0, %arg1, %cst_f0
197 ///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
198 ///  ```
199 struct CombineContractABTranspose final
200     : public OpRewritePattern<vector::ContractionOp> {
201   using OpRewritePattern::OpRewritePattern;
202 
203   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
204                                 PatternRewriter &rewriter) const override {
205     SmallVector<AffineMap> maps =
206         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
207     Value lhs = contractOp.getLhs();
208     Value rhs = contractOp.getRhs();
209     size_t index = 0;
210     bool changed = false;
211     for (Value *operand : {&lhs, &rhs}) {
212       AffineMap &map = maps[index++];
213       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
214       if (!transposeOp)
215         continue;
216       AffineMap permutationMap = AffineMap::getPermutationMap(
217           transposeOp.getPermutation(), contractOp.getContext());
218       map = inversePermutation(permutationMap).compose(map);
219       *operand = transposeOp.getVector();
220       changed = true;
221     }
222     if (!changed)
223       return failure();
224     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
225         contractOp, lhs, rhs, contractOp.getAcc(),
226         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
227     return success();
228   }
229 };
230 
231 /// Merges accumulator and result transposes into contract.
232 ///
233 /// For example:
234 /// ```mlir
235 /// %accT = vector.transpose %acc, [0, 2, 1]
236 ///   : vector<2x8x4xf32> to vector<2x4x8xf32>
237 /// %contract = vector.contract {
238 ///   indexing_maps = [
239 ///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
240 ///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
241 ///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
242 ///   ],
243 ///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
244 ///   kind = #vector.kind<add>
245 /// } %lhs, %rhs, %accT
246 ///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
247 /// %0 = vector.transpose %contract, [0, 2, 1]
248 ///   : vector<2x4x8xf32> to vector<2x8x4xf32>
249 /// ```
250 /// Becomes:
251 /// ```mlir
252 /// %0 = vector.contract {
253 ///   indexing_maps = [
254 ///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
255 ///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
256 ///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
257 ///   ],
258 ///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
259 ///   kind = #vector.kind<add>
260 /// } %lhs, %rhs, %acc
261 ///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
262 /// ```
263 struct CombineContractResultTranspose final
264     : public OpRewritePattern<vector::TransposeOp> {
265   using OpRewritePattern::OpRewritePattern;
266 
267   LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
268                                 PatternRewriter &rewriter) const override {
269     auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
270     if (!contractOp || !contractOp->hasOneUse())
271       return failure();
272 
273     auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
274     if (!accTOp)
275       return failure();
276 
277     MLIRContext *context = contractOp.getContext();
278     auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
279     AffineMap contractMap = maps.back();
280 
281     // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
282     // To index into A in contract, we need inverse(f)(g(C)) -> A.
283     auto accTMap =
284         AffineMap::getPermutationMap(accTOp.getPermutation(), context);
285 
286     // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
287     // To index into E in contract, we need h(g(C)) -> E.
288     auto resTMap =
289         AffineMap::getPermutationMap(resTOp.getPermutation(), context);
290     auto combinedResMap = resTMap.compose(contractMap);
291 
292     // The accumulator and result share the same indexing map, so merging is only
293     // valid when combinedResMap equals inversePermutation(accTMap).compose(contractMap),
294     // i.e. when resTMap is the inverse permutation of accTMap.
295     if (inversePermutation(accTMap) != resTMap)
296       return failure();
297     maps.back() = combinedResMap;
298 
299     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
300         resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
301         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
302     return success();
303   }
304 };
305 
306 /// Merge BroadcastOp into ContractionOp user.
307 /// Ex:
308 /// ```
309 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
310 ///   %1 = vector.contract {indexing_maps = [
311 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
312 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
313 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
314 ///    iterator_types = ["parallel", "parallel", "reduction"],
315 ///    kind = add} %0, %arg1, %cst_f0
316 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
317 /// ```
318 /// Gets converted to:
319 /// ```
320 ///   %1 = vector.contract {indexing_maps = [
321 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
322 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
323 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
324 ///    iterator_types = ["parallel", "parallel", "reduction"],
325 ///    kind = add} %arg0, %arg1, %cst_f0
326 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
327 ///  ```
328 struct CombineContractBroadcast
329     : public OpRewritePattern<vector::ContractionOp> {
330   using OpRewritePattern::OpRewritePattern;
331 
332   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
333                                 PatternRewriter &rewriter) const override {
334     SmallVector<AffineMap> maps =
335         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
336     Value lhs = contractOp.getLhs();
337     Value rhs = contractOp.getRhs();
338     size_t index = 0;
339     bool changed = false;
340     for (Value *operand : {&lhs, &rhs}) {
341       AffineMap &map = maps[index++];
342       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
343       if (!broadcast)
344         continue;
345       // ContractionOp can only take vectors as operands.
346       auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
347       if (!srcType ||
348           srcType.getRank() == broadcast.getResultVectorType().getRank())
349         continue;
350       int64_t rankDiff =
351           broadcast.getResultVectorType().getRank() - srcType.getRank();
352       bool innerDimBroadcast = false;
353       SmallVector<AffineExpr> originalDims;
354       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
355         if (dim.value() != broadcast.getResultVectorType().getDimSize(
356                                rankDiff + dim.index())) {
357           innerDimBroadcast = true;
358           break;
359         }
360         originalDims.push_back(
361             rewriter.getAffineDimExpr(dim.index() + rankDiff));
362       }
363       // Contract doesn't support inner dimension broadcast. Once this is
364       // relaxed we can remove this case.
365       if (innerDimBroadcast)
366         continue;
367 
368       // It would be incorrect to fold a broadcast onto a reduction dimension
369       // of non-unit size.
370       bool nonUnitDimReductionBroadcast = false;
371       for (int64_t i = 0; i < rankDiff; ++i) {
372         if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
373             isReductionIterator(contractOp.getIteratorTypes()
374                                     .getValue()[map.getDimPosition(i)])) {
375           nonUnitDimReductionBroadcast = true;
376           break;
377         }
378       }
379       if (nonUnitDimReductionBroadcast)
380         continue;
381 
382       AffineMap broadcastMap =
383           AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
384                          originalDims, contractOp.getContext());
385       map = broadcastMap.compose(map);
386       *operand = broadcast.getSource();
387       changed = true;
388     }
389 
390     if (!changed)
391       return failure();
392 
393     // Determine which dims are unused, now that the maps have been composed
394     // with the broadcast maps.
395     llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
396     // Compress unused dims.
397     for (auto &m : maps)
398       m = compressDims(m, unusedDimsBitVector);
399     // Compute the combined iterators.
400     SmallVector<Attribute> iterators;
401     for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
402       if (!unusedDimsBitVector.test(i))
403         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
404     }
405     // Check that compressing unused dims isn't removing all reduction dimension
406     // pairs. For example, if the vector.contract had only one reduction
407     // iterator and that was a unit-dimension created by a broadcast,
408     // then we should bail here, otherwise we would create a contract without
409     // a reduction dimension pair.
410     bool hasReductionIteratorApplyingOnBothSides = false;
411     for (unsigned i = 0; i < iterators.size(); ++i) {
412       if (!isReductionIterator(iterators[i]))
413         continue;
414       if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
415         hasReductionIteratorApplyingOnBothSides = true;
416         break;
417       }
418     }
419     if (!hasReductionIteratorApplyingOnBothSides)
420       return failure();
421 
422     // If the compressed maps have a dimension that is not used by either LHS or
423     // RHS then the ContractionOp verifier would fail.
424     if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
425       return failure();
426     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
427         contractOp, lhs, rhs, contractOp.getAcc(),
428         rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
429     return success();
430   }
431 };
432 
433 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
434 /// contraction ops closer, which allows the CombineContractBroadcast pattern to
435 /// kick in when casting ops are around these operations.
436 /// Ex:
437 /// ```
438 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
439 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
440 /// ```
441 /// Gets converted to:
442 /// ```
443 ///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
444 ///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
445 /// ```
446 struct ReorderCastOpsOnBroadcast
447     : public OpInterfaceRewritePattern<CastOpInterface> {
448   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
449 
450   LogicalResult matchAndRewrite(CastOpInterface op,
451                                 PatternRewriter &rewriter) const override {
452     if (op->getNumOperands() != 1)
453       return failure();
454     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
455     if (!bcastOp)
456       return failure();
457 
458     Type castResTy = getElementTypeOrSelf(op->getResult(0));
459     if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
460       castResTy = vecTy.clone(castResTy);
461     auto *castOp =
462         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
463                         bcastOp.getSource(), castResTy, op->getAttrs());
464     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
465         op, op->getResult(0).getType(), castOp->getResult(0));
466     return success();
467   }
468 };
469 
470 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
471 /// transpose ops and contraction ops closer, which allows the
472 /// CombineContractABTranspose pattern to kick in when elementwise ops are between these
473 /// operations. Ex:
474 /// ```
475 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
476 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
477 /// %r = arith.addf %at, %bt : vector<2x4xf32>
478 /// ```
479 /// Gets converted to:
480 /// ```
481 /// %0 = arith.addf %a, %b : vector<4x2xf32>
482 /// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
483 /// ```
484 struct ReorderElementwiseOpsOnTranspose final
485     : public OpTraitRewritePattern<OpTrait::Elementwise> {
486   using OpTraitRewritePattern::OpTraitRewritePattern;
487   LogicalResult matchAndRewrite(Operation *op,
488                                 PatternRewriter &rewriter) const override {
489     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
490       return failure();
491 
492     // Make sure all operands are transpose/constant ops and collect their
493     // transposition maps.
494     SmallVector<ArrayRef<int64_t>> transposeMaps;
495     transposeMaps.reserve(op->getNumOperands());
496     // Record the initial type before transposition. We'll use its shape later.
497     // Any type will do here as we will check all transpose maps are the same.
498     VectorType srcType;
499     for (Value operand : op->getOperands()) {
500       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
501       if (transposeOp) {
502         transposeMaps.push_back(transposeOp.getPermutation());
503         srcType = transposeOp.getSourceVectorType();
504       } else if (!matchPattern(operand, m_Constant())) {
505         return failure();
506       }
507     }
508     if (transposeMaps.empty())
509       return failure();
510     // This is an elementwise op, so all transposed operands should have the
511     // same type. We need to additionally check that all transposes use the
512     // same map.
513     if (!llvm::all_equal(transposeMaps))
514       return rewriter.notifyMatchFailure(op, "different transpose map");
515 
516     SmallVector<Value> srcValues;
517     srcValues.reserve(op->getNumOperands());
518 
519     // If there are constant operands, we need to insert inverse transposes for
520     // them. Calculate the inverse order first.
521     auto order = transposeMaps.front();
522     SmallVector<int64_t> invOrder(order.size());
523     for (int i = 0, e = order.size(); i < e; ++i)
524       invOrder[order[i]] = i;
525 
526     for (Value operand : op->getOperands()) {
527       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
528       if (transposeOp) {
529         srcValues.push_back(transposeOp.getVector());
530       } else {
531         // This is a constant. Create a reverse transpose op for it.
532         auto vectorType =
533             srcType.clone(cast<VectorType>(operand.getType()).getElementType());
534         srcValues.push_back(rewriter.create<vector::TransposeOp>(
535             operand.getLoc(), vectorType, operand, invOrder));
536       }
537     }
538 
539     auto vectorType = srcType.clone(
540         cast<VectorType>(op->getResultTypes()[0]).getElementType());
541     Operation *elementwiseOp =
542         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
543                         vectorType, op->getAttrs());
544     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
545         op, op->getResultTypes()[0], elementwiseOp->getResult(0),
546         transposeMaps.front());
547     return success();
548   }
549 };
550 
551 // Returns the values in `arrayAttr` as an integer vector.
552 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
553   return llvm::to_vector<4>(
554       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
555                       [](IntegerAttr attr) { return attr.getInt(); }));
556 }
557 
558 // Shuffles vector.bitcast op after vector.extract op.
559 //
560 // This transforms IR like:
561 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
562 //   %1 = vector.extract %0[3] : f16 from vector<8xf16>
563 // Into:
564 //   %0 = vector.extract %src[1] : f32 from vector<4xf32>
565 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
566 //   %2 = vector.extract %1[1] : f16 from vector<2xf16>
567 struct BubbleDownVectorBitCastForExtract
568     : public OpRewritePattern<vector::ExtractOp> {
569   using OpRewritePattern::OpRewritePattern;
570 
571   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
572                                 PatternRewriter &rewriter) const override {
573     // Only support extracting scalars for now.
574     if (extractOp.getSourceVectorType().getRank() != 1)
575       return failure();
576 
577     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
578     if (!castOp)
579       return failure();
580 
581     VectorType castSrcType = castOp.getSourceVectorType();
582     VectorType castDstType = castOp.getResultVectorType();
583     assert(castSrcType.getRank() == castDstType.getRank());
584 
585     // Fail to match if we only have one element in the cast op source.
586     // This is to avoid an infinite loop given that this pattern can generate
587     // such cases.
588     if (castSrcType.getNumElements() == 1)
589       return failure();
590 
591     // Only support casting to a larger number of elements for now.
592     // E.g., vector<4xf32> -> vector<8xf16>.
593     if (castSrcType.getNumElements() > castDstType.getNumElements())
594       return failure();
595 
596     unsigned expandRatio =
597         castDstType.getNumElements() / castSrcType.getNumElements();
598 
599     // Get the first element of the mixed position as integer.
600     auto mixedPos = extractOp.getMixedPosition();
601     if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
602       return failure();
603     uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
604 
605     // Get the single scalar (as a vector) in the source value that packs the
606     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
607     Location loc = extractOp.getLoc();
608     Value packedValue = rewriter.create<vector::ExtractOp>(
609         loc, castOp.getSource(), index / expandRatio);
610     Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
611     Value zero = rewriter.create<arith::ConstantOp>(
612         loc, packedVecType, rewriter.getZeroAttr(packedVecType));
613     packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
614                                                     /*position=*/0);
615 
616     // Cast it to a vector with the desired scalar's type.
617     // E.g. vector<1xf32> -> vector<2xf16>
618     VectorType packedType =
619         VectorType::get({expandRatio}, castDstType.getElementType());
620     Value castedValue =
621         rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
622 
623     // Finally extract the desired scalar.
624     rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
625                                                    index % expandRatio);
626     return success();
627   }
628 };
629 
630 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
631 //
632 // This transforms IR like:
633 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
634 //     %0 = vector.extract_strided_slice %cast {
635 //            offsets = [4], sizes = [4], strides = [1]
636 //          } : vector<8xf16> to vector<4xf16>
637 // Into:
638 //   %0 = vector.extract_strided_slice %src {
639 //          offsets = [2], sizes = [2], strides = [1]
640 //        } : vector<4xf32> to vector<2xf32>
641 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
642 struct BubbleDownBitCastForStridedSliceExtract
643     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
644   using OpRewritePattern::OpRewritePattern;
645 
646   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
647                                 PatternRewriter &rewriter) const override {
648     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
649     if (!castOp)
650       return failure();
651 
652     VectorType castSrcType = castOp.getSourceVectorType();
653     VectorType castDstType = castOp.getResultVectorType();
654     assert(castSrcType.getRank() == castDstType.getRank());
655 
656     int64_t castSrcLastDim = castSrcType.getShape().back();
657     int64_t castDstLastDim = castDstType.getShape().back();
658     // Require casting to more elements for now; other cases to be implemented.
659     if (castSrcLastDim > castDstLastDim)
660       return failure();
661 
662     // Only accept all one strides for now.
663     if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
664                      [](const APInt &val) { return !val.isOne(); }))
665       return failure();
666 
667     unsigned rank = extractOp.getSourceVectorType().getRank();
668     assert(castDstLastDim % castSrcLastDim == 0);
669     int64_t expandRatio = castDstLastDim / castSrcLastDim;
670 
671     // If we have fewer offsets than the rank, then implicitly we are selecting
672     // the full range for the last bitcasted dimension; other dimensions aren't
673     // affected. Otherwise, we need to scale down the last dimension's offset
674     // given we are extracting from fewer elements now.
675     ArrayAttr newOffsets = extractOp.getOffsets();
676     if (newOffsets.size() == rank) {
677       SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
678       if (offsets.back() % expandRatio != 0)
679         return failure();
680       offsets.back() = offsets.back() / expandRatio;
681       newOffsets = rewriter.getI64ArrayAttr(offsets);
682     }
683 
684     // Similarly for sizes.
685     ArrayAttr newSizes = extractOp.getSizes();
686     if (newSizes.size() == rank) {
687       SmallVector<int64_t> sizes = getIntValueVector(newSizes);
688       if (sizes.back() % expandRatio != 0)
689         return failure();
690       sizes.back() = sizes.back() / expandRatio;
691       newSizes = rewriter.getI64ArrayAttr(sizes);
692     }
693 
694     SmallVector<int64_t> dims =
695         llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
696     dims.back() = dims.back() / expandRatio;
697     VectorType newExtractType =
698         VectorType::get(dims, castSrcType.getElementType());
699 
700     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
701         extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
702         newSizes, extractOp.getStrides());
703 
704     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
705         extractOp, extractOp.getType(), newExtractOp);
706 
707     return success();
708   }
709 };
710 
711 // Shuffles vector.bitcast op before vector.insert op.
712 //
713 // This transforms IR like:
714 //   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
715 //   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
716 // Into:
717 //   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
718 //   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
719 //   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
720 //
721 struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
722   using OpRewritePattern::OpRewritePattern;
723 
724   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
725                                 PatternRewriter &rewriter) const override {
726     VectorType castSrcType = bitcastOp.getSourceVectorType();
727     VectorType castDstType = bitcastOp.getResultVectorType();
728 
729     // 0-D and scalable vectors are not supported yet.
730     if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
731         castDstType.isScalable())
732       return failure();
733 
734     int64_t castSrcLastDim = castSrcType.getShape().back();
735     int64_t castDstLastDim = castDstType.getShape().back();
736     bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
737     int64_t ratio;
738     if (isNumElemsShrink) {
739       assert(castSrcLastDim % castDstLastDim == 0);
740       ratio = castSrcLastDim / castDstLastDim;
741     } else {
742       assert(castDstLastDim % castSrcLastDim == 0);
743       ratio = castDstLastDim / castSrcLastDim;
744     }
745 
746     auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
747     if (!insertOp)
748       return failure();
749 
750     // Only vector sources are supported for now.
751     auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
752     if (!insertSrcType)
753       return failure();
754 
755     // Bitcast the source.
756     SmallVector<int64_t> srcDims(insertSrcType.getShape());
757     srcDims.back() =
758         isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
759     VectorType newCastSrcType =
760         VectorType::get(srcDims, castDstType.getElementType());
761     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
762         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
763 
764     SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
765     dstDims.back() =
766         isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
767     VectorType newCastDstType =
768         VectorType::get(dstDims, castDstType.getElementType());
769 
770     // Bitcast the destination.
771     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
772         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
773 
774     // Generate new insert.
775     rewriter.replaceOpWithNewOp<vector::InsertOp>(
776         bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
777     return success();
778   }
779 };
780 
781 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
782 //
783 // This transforms IR like:
784 //   %0 = vector.insert_strided_slice %src, %dst {
785 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
786 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
787 // Into:
788 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
789 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
790 //   %2 = vector.insert_strided_slice %0, %1 {
791 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
792 struct BubbleUpBitCastForStridedSliceInsert
793     : public OpRewritePattern<vector::BitCastOp> {
794   using OpRewritePattern::OpRewritePattern;
795 
796   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
797                                 PatternRewriter &rewriter) const override {
798     VectorType castSrcType = bitcastOp.getSourceVectorType();
799     VectorType castDstType = bitcastOp.getResultVectorType();
800     assert(castSrcType.getRank() == castDstType.getRank());
801     // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
802     if (castSrcType.getRank() == 0)
803       return failure();
804 
805     int64_t castSrcLastDim = castSrcType.getShape().back();
806     int64_t castDstLastDim = castDstType.getShape().back();
807     // Require casting to less elements for now; other cases to be implemented.
808     if (castSrcLastDim < castDstLastDim)
809       return failure();
810 
811     assert(castSrcLastDim % castDstLastDim == 0);
812     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
813 
814     auto insertOp =
815         bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
816     if (!insertOp)
817       return failure();
818 
819     // Only accept all one strides for now.
820     if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
821                      [](const APInt &val) { return !val.isOne(); }))
822       return failure();
823 
824     unsigned rank = insertOp.getSourceVectorType().getRank();
825     // Require insert op to have the same rank for the source and destination
826     // vector; other cases to be implemented.
827     if (rank != insertOp.getDestVectorType().getRank())
828       return failure();
829 
830     // Requires that shape of insert op src is castable to dstType.
831     unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
832     unsigned destinationWidth =
833         castDstType.getElementType().getIntOrFloatBitWidth();
834     unsigned numElements = destinationWidth / sourceWidth;
835     if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
836       return failure();
837 
838     ArrayAttr newOffsets = insertOp.getOffsets();
839     assert(newOffsets.size() == rank);
840     SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
841     if (offsets.back() % shrinkRatio != 0)
842       return failure();
843     offsets.back() = offsets.back() / shrinkRatio;
844     newOffsets = rewriter.getI64ArrayAttr(offsets);
845 
846     SmallVector<int64_t> srcDims =
847         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
848     srcDims.back() = srcDims.back() / shrinkRatio;
849     VectorType newCastSrcType =
850         VectorType::get(srcDims, castDstType.getElementType());
851 
852     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
853         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
854 
855     SmallVector<int64_t> dstDims =
856         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
857     dstDims.back() = dstDims.back() / shrinkRatio;
858     VectorType newCastDstType =
859         VectorType::get(dstDims, castDstType.getElementType());
860 
861     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
862         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
863 
864     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
865         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
866         insertOp.getStrides());
867 
868     return success();
869   }
870 };
871 
872 // Breaks down vector.bitcast op
873 //
874 // This transforms IR like:
875 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
876 // Into:
877 //   %cst = vector.splat %c0_f32 : vector<4xf32>
878 //   %1 = vector.extract_strided_slice %0 {
879 //          offsets = [0], sizes = [4], strides = [1]
880 //        } : vector<8xf16> to vector<4xf16>
881 //   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
882 //   %4 = vector.insert_strided_slice %2, %cst {
883 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
884 //   %5 = vector.extract_strided_slice %0 {
885 //          offsets = [4], sizes = [4], strides = [1]
886 //        } : vector<8xf16> to vector<4xf16>
887 //   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
888 //   %7 = vector.insert_strided_slice %6, %cst {
889 //          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
890 struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
891   using OpRewritePattern::OpRewritePattern;
892 
893 public:
894   BreakDownVectorBitCast(MLIRContext *context,
895                          std::function<bool(vector::BitCastOp)> controlFn,
896                          PatternBenefit benefit)
897       : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
898 
899   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
900                                 PatternRewriter &rewriter) const override {
901 
902     if (controlFn && !controlFn(bitcastOp))
903       return failure();
904 
905     VectorType castSrcType = bitcastOp.getSourceVectorType();
906     VectorType castDstType = bitcastOp.getResultVectorType();
907     assert(castSrcType.getRank() == castDstType.getRank());
908 
909     // This transformation builds on top of
910     // vector.{extract|insert}_strided_slice, which do not support
911     // extracting/inserting "scallable sub-vectors". Bail out.
912     if (castSrcType.isScalable())
913       return rewriter.notifyMatchFailure(bitcastOp,
914                                          "Scalable vectors are not supported");
915 
916     // Only support rank 1 case for now.
917     if (castSrcType.getRank() != 1)
918       return failure();
919 
920     int64_t castSrcLastDim = castSrcType.getShape().back();
921     int64_t castDstLastDim = castDstType.getShape().back();
922     // Require casting to less elements for now; other cases to be implemented.
923     if (castSrcLastDim < castDstLastDim)
924       return failure();
925 
926     assert(castSrcLastDim % castDstLastDim == 0);
927     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
928     // Nothing to do if it is already bitcasting to a single element.
929     if (castSrcLastDim == shrinkRatio)
930       return failure();
931 
932     Location loc = bitcastOp.getLoc();
933     Type elemType = castDstType.getElementType();
934     assert(elemType.isSignlessIntOrIndexOrFloat());
935 
936     Value zero = rewriter.create<arith::ConstantOp>(
937         loc, elemType, rewriter.getZeroAttr(elemType));
938     Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
939 
940     SmallVector<int64_t> sliceShape = {castDstLastDim};
941     SmallVector<int64_t> strides = {1};
942     VectorType newCastDstType =
943         VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
944                         castDstType.getElementType());
945 
946     for (int i = 0, e = shrinkRatio; i < e; ++i) {
947       Value extracted = rewriter.create<ExtractStridedSliceOp>(
948           loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
949           sliceShape, strides);
950       Value bitcast =
951           rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
952       res = rewriter.create<InsertStridedSliceOp>(
953           loc, bitcast, res,
954           ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
955     }
956     rewriter.replaceOp(bitcastOp, res);
957     return success();
958   }
959 
960 private:
961   std::function<bool(BitCastOp)> controlFn;
962 };
963 
964 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
965 /// ```
966 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
967 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
968 /// %r = arith.addi %a, %b : vector<1x4xindex>
969 /// ```
970 /// Gets converted to:
971 /// ```
972 /// %r = arith.addi %arg0, %arg1 : index
973 /// %b = vector.broadcast %r : index to vector<1x4xindex>
974 /// ```
975 ///
976 /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
977 /// ops.
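///
/// As a sketch of the `vector.splat` case (the names below are illustrative
/// and assume scalar f32 arguments, not taken from a specific test):
/// ```
/// %a = vector.splat %arg1 : vector<4xf32>
/// %b = vector.splat %arg2 : vector<4xf32>
/// %r = arith.mulf %a, %b : vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.mulf %arg1, %arg2 : f32
/// %r = vector.broadcast %0 : f32 to vector<4xf32>
/// ```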
978 struct ReorderElementwiseOpsOnBroadcast final
979     : public OpTraitRewritePattern<OpTrait::Elementwise> {
980   using OpTraitRewritePattern::OpTraitRewritePattern;
981   LogicalResult matchAndRewrite(Operation *op,
982                                 PatternRewriter &rewriter) const override {
983     if (op->getNumResults() != 1)
984       return failure();
985     if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
986       return failure();
987     if (!OpTrait::hasElementwiseMappableTraits(op))
988       return rewriter.notifyMatchFailure(
989           op, "Op doesn't have ElementwiseMappableTraits");
990     if (op->getNumOperands() == 0)
991       return failure();
992     if (op->getResults()[0].getType() != op->getOperand(0).getType())
993       return rewriter.notifyMatchFailure(op,
994                                          "result and operand type mismatch");
995     if (isa<vector::FMAOp>(op)) {
996       return rewriter.notifyMatchFailure(
997           op,
998           "Op only accepts vector types - not supported as broadcast source "
999           "might be a scalar");
1000     }
1001 
1002     // Get the type of the lhs operand
1003     auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
1004     if (!lhsBcastOrSplat ||
1005         !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
1006       return failure();
1007     auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
1008 
1009     // Make sure that all operands are broadcast from identical types:
1010     //  * scalar (`vector.broadcast` + `vector.splat`), or
1011     //  * vector (`vector.broadcast`).
1012     // Otherwise the re-ordering wouldn't be safe.
1013     if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
1014           auto bcast = val.getDefiningOp<vector::BroadcastOp>();
1015           if (bcast)
1016             return (bcast.getOperand().getType() == lhsBcastOrSplatType);
1017           auto splat = val.getDefiningOp<vector::SplatOp>();
1018           if (splat)
1019             return (splat.getOperand().getType() == lhsBcastOrSplatType);
1020           return false;
1021         })) {
1022       return failure();
1023     }
1024 
1025     // Collect the source values before broadcasting
1026     SmallVector<Value> srcValues;
1027     srcValues.reserve(op->getNumOperands());
1028     for (Value operand : op->getOperands()) {
1029       srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1030     }
1031 
1032     // Create the "elementwise" Op
1033     Operation *elementwiseOp =
1034         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1035                         lhsBcastOrSplatType, op->getAttrs());
1036 
1037     // Replace the original Op with the elementwise Op
1038     auto vectorType = op->getResultTypes()[0];
1039     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1040         op, vectorType, elementwiseOp->getResults());
1041 
1042     return success();
1043   }
1044 };
1045 
1046 // Helper that returns a vector comparison that constructs a mask:
1047 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1048 //
1049 // If `dim == 0` then the result will be a 0-D vector.
1050 //
1051 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1052 //       much more compact, IR for this operation, but LLVM eventually
1053 //       generates more elaborate instructions for this intrinsic since it
1054 //       is very conservative on the boundary conditions.
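//
// As a rough sketch (values and names below are illustrative), for `dim == 4`
// with 32-bit indices, an offset `%off`, and a bound `%b` already of type i32,
// the emitted IR looks roughly like:
//
//   %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %ov  = vector.splat %off : vector<4xi32>
//   %idx = arith.addi %ov, %cst : vector<4xi32>
//   %bv  = vector.splat %b : vector<4xi32>
//   %m   = arith.cmpi slt, %idx, %bv : vector<4xi32>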
1055 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1056                                    bool force32BitVectorIndices, int64_t dim,
1057                                    Value b, Value *off = nullptr) {
1058   auto loc = op->getLoc();
1059   // If we can assume all indices fit in 32-bit, we perform the vector
1060   // comparison in 32-bit to get a higher degree of SIMD parallelism.
1061   // Otherwise we perform the vector comparison using 64-bit indices.
1062   Type idxType =
1063       force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1064   DenseIntElementsAttr indicesAttr;
1065   if (dim == 0 && force32BitVectorIndices) {
1066     indicesAttr = DenseIntElementsAttr::get(
1067         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
1068   } else if (dim == 0) {
1069     indicesAttr = DenseIntElementsAttr::get(
1070         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
1071   } else if (force32BitVectorIndices) {
1072     indicesAttr = rewriter.getI32VectorAttr(
1073         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1074   } else {
1075     indicesAttr = rewriter.getI64VectorAttr(
1076         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1077   }
1078   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
1079   // Add in an offset if requested.
1080   if (off) {
1081     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1082     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
1083     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
1084   }
1085   // Construct the vector comparison.
1086   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1087   Value bounds =
1088       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
1089   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
1090                                         bounds);
1091 }
1092 
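/// Materializes an explicit mask for a vector.transfer_read/write (rank <= 1
/// only) that has an out-of-bounds dimension, and then marks the access
/// in-bounds. A rough sketch of the rewrite (types and names below are
/// illustrative):
/// ```
/// %v = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes
/// ```
/// %d = memref.dim %src, %c0 : memref<?xf32>
/// %b = arith.subi %d, %i : index
/// %m = vector.create_mask %b : vector<8xi1>
/// %v = vector.transfer_read %src[%i], %pad, %m {in_bounds = [true]}
///   : memref<?xf32>, vector<8xf32>
/// ```
/// Any mask already attached to the transfer is AND-ed with the new one.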
1093 template <typename ConcreteOp>
1094 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1095 public:
1096   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1097                                    PatternBenefit benefit = 1)
1098       : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1099         force32BitVectorIndices(enableIndexOpt) {}
1100 
1101   LogicalResult matchAndRewrite(ConcreteOp xferOp,
1102                                 PatternRewriter &rewriter) const override {
1103     if (!xferOp.hasOutOfBoundsDim())
1104       return failure();
1105 
1106     if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1107       return failure();
1108 
1109     Location loc = xferOp->getLoc();
1110     VectorType vtp = xferOp.getVectorType();
1111 
1112     // Create the in-bounds mask with all elements between [0 .. dim - offset)
1113     // set and [dim - offset .. vector_length) unset.
1114     //
1115     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1116     //       dimensions here.
1117     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1118     Value off = xferOp.getIndices()[lastIndex];
1119     Value dim =
1120         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
1121     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
1122     Value mask = rewriter.create<vector::CreateMaskOp>(
1123         loc,
1124         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1125                         vtp.getScalableDims()),
1126         b);
1127     if (xferOp.getMask()) {
1128       // Intersect the in-bounds with the mask specified as an op parameter.
1129       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
1130     }
1131 
1132     rewriter.modifyOpInPlace(xferOp, [&]() {
1133       xferOp.getMaskMutable().assign(mask);
1134       xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1135     });
1136 
1137     return success();
1138   }
1139 
1140 private:
1141   const bool force32BitVectorIndices;
1142 };
1143 
1144 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
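/// A rough sketch of the 1-D lowering via buildVectorComparison (names are
/// illustrative; with `force32BitVectorIndices` the comparison is done in i32):
/// ```
/// %m = vector.create_mask %b : vector<4xi1>
/// ```
/// becomes
/// ```
/// %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
/// %b32 = arith.index_cast %b : index to i32
/// %bv  = vector.splat %b32 : vector<4xi32>
/// %m   = arith.cmpi slt, %cst, %bv : vector<4xi32>
/// ```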
1145 class VectorCreateMaskOpConversion
1146     : public OpRewritePattern<vector::CreateMaskOp> {
1147 public:
1148   explicit VectorCreateMaskOpConversion(MLIRContext *context,
1149                                         bool enableIndexOpt,
1150                                         PatternBenefit benefit = 1)
1151       : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1152         force32BitVectorIndices(enableIndexOpt) {}
1153 
1154   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1155                                 PatternRewriter &rewriter) const override {
1156     auto dstType = op.getType();
1157     if (cast<VectorType>(dstType).isScalable())
1158       return failure();
1159     int64_t rank = dstType.getRank();
1160     if (rank > 1)
1161       return failure();
1162     rewriter.replaceOp(
1163         op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1164                                   rank == 0 ? 0 : dstType.getDimSize(0),
1165                                   op.getOperand(0)));
1166     return success();
1167   }
1168 
1169 private:
1170   const bool force32BitVectorIndices;
1171 };
1172 
1173 /// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1174 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1175   auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1176   // TODO: Support non-dense constant.
1177   if (!denseAttr)
1178     return false;
1179 
1180   assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1181   return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1182 }
1183 
1184 /// Folds a select operation between an all-true and all-false vector. For now,
1185 /// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1186 ///
1187 ///   %true = arith.constant dense<true> : vector<1xi1>
1188 ///   %false = arith.constant dense<false> : vector<1xi1>
1189 ///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1190 ///   =>
1191 ///   %result = vector.broadcast %cond : i1 to vector<1xi1>
1192 ///
1193 /// InstCombine seems to handle vectors with multiple elements but not the
1194 /// single element ones.
1195 struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1196   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1197 
1198   LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1199                                 PatternRewriter &rewriter) const override {
1200     auto vecType = dyn_cast<VectorType>(selectOp.getType());
1201     if (!vecType || !vecType.getElementType().isInteger(1))
1202       return failure();
1203 
1204     // Only scalar conditions can be folded.
1205     Value cond = selectOp.getCondition();
1206     if (isa<VectorType>(cond.getType()))
1207       return failure();
1208 
1209     // TODO: Support n-D and scalable vectors.
1210     if (vecType.getRank() != 1 || vecType.isScalable())
1211       return failure();
1212 
1213     // TODO: Support vectors with multiple elements.
1214     if (vecType.getShape()[0] != 1)
1215       return failure();
1216 
1217     auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1218     if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1219       return failure();
1220 
1221     auto falseConst =
1222         selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1223     if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1224       return failure();
1225 
1226     // Replace select with its condition broadcasted to single element vector.
1227     auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1228     auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1229     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1230     return success();
1231   }
1232 };
1233 
1234 /// Returns the number of dims that can be folded away from transfer ops. It
1235 /// returns a failure if it cannot determine the number of dims to be folded.
1236 ///
1237 /// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1238 /// `vectorType` is vector<16x16x1x1xf32>
1239 /// (the two innermost dims can be dropped by memref.subview ops)
1240 ///
1241 /// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1242 /// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1243 /// (only the innermost unit dim of `srcType` can be dropped)
1244 ///
1245 /// Ex 3: return "0" if `srcType` is memref<512x16x1x1xf32> and
1246 /// `vectorType` is vector<16x16x1x[1]xf32>
1247 /// (the innermost dim in `vectorType` is not a unit dim; it is a "scalable
1248 /// unit" dim)
1249 static FailureOr<size_t>
1250 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1251   SmallVector<int64_t> srcStrides;
1252   int64_t srcOffset;
1253   if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
1254     return failure();
1255 
1256   auto isUnitDim = [](VectorType type, int dim) {
1257     return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1258   };
1259 
1260   // According to vector.transfer_read/write semantics, the vector can be a
1261   // slice. Thus, we have to offset the checked index by `rankDiff` when
1262   // indexing into `srcStrides` and the source dim sizes.
1263   size_t result = 0;
1264   int rankDiff = srcType.getRank() - vectorType.getRank();
1265   for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1266     // Check that the inner dim size is 1 for both memref type and vector slice.
1267     // It can be folded only if they are 1 and the stride is 1.
1268     int dim = vectorType.getRank() - i - 1;
1269     if (srcStrides[dim + rankDiff] != 1 ||
1270         srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1271       break;
1272     result++;
1273   }
1274   return result;
1275 }
1276 
1277 /// Drop innermost contiguous unit dimensions from transfer_read operand.
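/// E.g. (a sketch; shapes chosen for illustration, analogous to the
/// transfer_write example further below),
///
///    %0 = vector.transfer_read %arg0[%c0, %arg1, %c0, %c0], %cst
///      {in_bounds = [true, true, true, true]}
///      : memref<1x512x16x1xf32>, vector<1x16x16x1xf32>
///
/// would be rewritten to read from a rank-reduced subview and to restore the
/// original vector type with a shape_cast:
///
///    %subview = memref.subview %arg0
///      [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
///      : memref<1x512x16x1xf32> to memref<1x512x16xf32>
///    %1 = vector.transfer_read %subview[%c0, %arg1, %c0], %cst
///      {in_bounds = [true, true, true]}
///      : memref<1x512x16xf32>, vector<1x16x16xf32>
///    %2 = vector.shape_cast %1 : vector<1x16x16xf32> to vector<1x16x16x1xf32>
///
/// Like the write pattern, this does not collapse "scalable unit" dims
/// (i.e. `[1]`).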
1278 class DropInnerMostUnitDimsTransferRead
1279     : public OpRewritePattern<vector::TransferReadOp> {
1280   using OpRewritePattern::OpRewritePattern;
1281 
1282   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1283                                 PatternRewriter &rewriter) const override {
1284     // TODO: support 0-d corner case.
1285     if (readOp.getTransferRank() == 0)
1286       return failure();
1287 
1288     // TODO: support mask.
1289     if (readOp.getMask())
1290       return failure();
1291 
1292     auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1293     if (!srcType)
1294       return failure();
1295 
1296     if (!readOp.getPermutationMap().isMinorIdentity())
1297       return failure();
1298 
1299     auto targetType = readOp.getVectorType();
1300     if (targetType.getRank() <= 1)
1301       return failure();
1302 
1303     FailureOr<size_t> maybeDimsToDrop =
1304         getTransferFoldableInnerUnitDims(srcType, targetType);
1305     if (failed(maybeDimsToDrop))
1306       return failure();
1307 
1308     size_t dimsToDrop = maybeDimsToDrop.value();
1309     if (dimsToDrop == 0)
1310       return failure();
1311 
1312     auto inBounds = readOp.getInBoundsValues();
1313     auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1314     if (llvm::is_contained(droppedInBounds, false))
1315       return failure();
1316 
1317     auto resultTargetVecType =
1318         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1319                         targetType.getElementType(),
1320                         targetType.getScalableDims().drop_back(dimsToDrop));
1321 
1322     auto loc = readOp.getLoc();
1323     SmallVector<OpFoldResult> sizes =
1324         memref::getMixedSizes(rewriter, loc, readOp.getSource());
1325     SmallVector<OpFoldResult> offsets(srcType.getRank(),
1326                                       rewriter.getIndexAttr(0));
1327     SmallVector<OpFoldResult> strides(srcType.getRank(),
1328                                       rewriter.getIndexAttr(1));
1329     auto resultMemrefType =
1330         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1331             srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1332             strides));
1333     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1334         readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1335     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1336         loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
1337     auto permMap = getTransferMinorIdentityMap(
1338         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1339     Value result = rewriter.create<vector::TransferReadOp>(
1340         loc, resultTargetVecType, rankedReducedView,
1341         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1342         readOp.getPadding(),
1343         // TODO: support mask.
1344         /*mask=*/Value(), inBoundsAttr);
1345     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1346                                                      result);
1347     return success();
1348   }
1349 };
1350 
1351 /// Drop innermost contiguous unit dimensions from transfer_write operand.
1352 /// E.g.,
1353 ///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1354 ///      {in_bounds = [true, true, true, true, true]}
1355 ///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1356 ///
1357 /// will be replaced with
1358 ///
1359 ///    %subview = memref.subview %arg0
1360 ///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1361 ///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1362 ///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1363 ///      to vector<1x16x16xf32>
1364 ///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1365 ///      {in_bounds = [true, true, true]}
1366 ///      : vector<1x16x16xf32>, memref<1x512x16xf32>
1367 ///
1368 /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1369 class DropInnerMostUnitDimsTransferWrite
1370     : public OpRewritePattern<vector::TransferWriteOp> {
1371   using OpRewritePattern::OpRewritePattern;
1372 
1373   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1374                                 PatternRewriter &rewriter) const override {
1375     // TODO: support 0-d corner case.
1376     if (writeOp.getTransferRank() == 0)
1377       return failure();
1378 
1379     // TODO: support mask.
1380     if (writeOp.getMask())
1381       return failure();
1382 
1383     auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1384     if (!srcType)
1385       return failure();
1386 
1387     if (!writeOp.getPermutationMap().isMinorIdentity())
1388       return failure();
1389 
1390     auto targetType = writeOp.getVectorType();
1391     if (targetType.getRank() <= 1)
1392       return failure();
1393 
1394     FailureOr<size_t> maybeDimsToDrop =
1395         getTransferFoldableInnerUnitDims(srcType, targetType);
1396     if (failed(maybeDimsToDrop))
1397       return failure();
1398 
1399     size_t dimsToDrop = maybeDimsToDrop.value();
1400     if (dimsToDrop == 0)
1401       return failure();
1402 
1403     auto inBounds = writeOp.getInBoundsValues();
1404     auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1405     if (llvm::is_contained(droppedInBounds, false))
1406       return failure();
1407 
1408     auto resultTargetVecType =
1409         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1410                         targetType.getElementType(),
1411                         targetType.getScalableDims().drop_back(dimsToDrop));
1412 
1413     Location loc = writeOp.getLoc();
1414     SmallVector<OpFoldResult> sizes =
1415         memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1416     SmallVector<OpFoldResult> offsets(srcType.getRank(),
1417                                       rewriter.getIndexAttr(0));
1418     SmallVector<OpFoldResult> strides(srcType.getRank(),
1419                                       rewriter.getIndexAttr(1));
1420     auto resultMemrefType =
1421         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1422             srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1423             strides));
1424     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1425         writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1426 
1427     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1428         loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
1429     auto permMap = getTransferMinorIdentityMap(
1430         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1431 
1432     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1433         loc, resultTargetVecType, writeOp.getVector());
1434     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1435         writeOp, shapeCast, rankedReducedView,
1436         writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1437         // TODO: support mask.
1438         /*mask=*/Value(), inBoundsAttr);
1439     return success();
1440   }
1441 };
1442 
1443 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1444 /// semantics to a contraction suitable for MMT (matrix-matrix multiplication
1445 /// with the RHS transposed) lowering.
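///
/// As a sketch, with indexing maps written schematically over dims (m, n, k):
///
///   lhs/rhs/acc maps [(m, k), (k, n), (m, n)]   // A row-major, B row-major
///
/// are rewritten, by transposing the RHS, into the canonical "TNT" form
///
///   lhs/rhs/acc maps [(m, k), (n, k), (m, n)]   // A row-major, B col-major
///
/// Other operand orders are handled by transposing and/or swapping the LHS
/// and RHS accordingly.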
1446 struct CanonicalizeContractMatmulToMMT final
1447     : OpRewritePattern<vector::ContractionOp> {
1448   using OpRewritePattern::OpRewritePattern;
1449 
1450   using FilterConstraintType =
1451       std::function<LogicalResult(vector::ContractionOp op)>;
1452 
1453   CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1454                                   FilterConstraintType constraint)
1455       : OpRewritePattern<vector::ContractionOp>(context, benefit),
1456         filter(std::move(constraint)) {}
1457 
1458   LogicalResult matchAndRewrite(vector::ContractionOp op,
1459                                 PatternRewriter &rewriter) const override {
1460     if (failed(filter(op)))
1461       return failure();
1462 
1463     Location loc = op.getLoc();
1464     Value lhs = op.getLhs();
1465     Value rhs = op.getRhs();
1466     Value res = op.getAcc();
1467 
1468     // Set up the parallel/reduction structure in the right form.
1469     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1470     auto infer = [&](MapList m) {
1471       return AffineMap::inferFromExprList(m, op.getContext());
1472     };
1473     AffineExpr m;
1474     AffineExpr n;
1475     AffineExpr k;
1476     bindDims(rewriter.getContext(), m, n, k);
1477     static constexpr std::array<int64_t, 2> perm = {1, 0};
1478     auto iteratorTypes = op.getIteratorTypes().getValue();
1479     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1480     if (iteratorTypes.size() != 3 ||
1481         !vector::isParallelIterator(iteratorTypes[0]) ||
1482         !vector::isParallelIterator(iteratorTypes[1]) ||
1483         !vector::isReductionIterator(iteratorTypes[2]))
1484       return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1485 
1486     // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1487     const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1488     if (maps == canonicalForm)
1489       return rewriter.notifyMatchFailure(op, "already in the canonical form");
1490 
1491     // Create a vector transpose making sure to emit zero/sign-extend at the
1492     // end.
1493     auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1494       if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1495         Value trans =
1496             rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1497         VectorType newType =
1498             cast<VectorType>(trans.getType())
1499                 .clone(cast<VectorType>(mat.getType()).getElementType());
1500         return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1501       }
1502       if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1503         Value trans =
1504             rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1505         VectorType newType =
1506             VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1507                             cast<VectorType>(mat.getType()).getElementType());
1508         return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1509       }
1510       return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1511     };
1512 
1513     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1514       rhs = createTranspose(rhs);
1515     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1516       lhs = createTranspose(lhs);
1517     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1518       rhs = createTranspose(rhs);
1519       lhs = createTranspose(lhs);
1520     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1521       std::swap(rhs, lhs);
1522       rhs = createTranspose(rhs);
1523       lhs = createTranspose(lhs);
1524     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1525       std::swap(rhs, lhs);
1526       rhs = createTranspose(rhs);
1527     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1528       std::swap(lhs, rhs);
1529       lhs = createTranspose(lhs);
1530     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1531       std::swap(lhs, rhs);
1532     } else {
1533       return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1534     }
1535     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1536         op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1537         op.getIteratorTypes());
1538     return success();
1539   };
1540 
1541 private:
1542   FilterConstraintType filter;
1543 };
1544 
1545 /// Pattern to fold arithmetic extensions on contraction operands into the
1546 /// vector contraction op. For example, linalg.matmul introduces arithmetic
1547 /// extensions on its operands; see the MLIR snippet below for more details.
1548 /// ```mlir
1549 ///   "linalg.matmul"(%lhs, %rhs, %acc) ({
1550 ///      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1551 ///        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1552 ///        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1553 ///        %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1554 ///        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1555 ///        "linalg.yield"(%acc) : (f32) -> ()
1556 ///     })
1557 /// ```
1558 /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1559 /// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1560 /// This pattern folds the arithmetic extensions into the vector contraction and
1561 /// enables the usage of native mixed precision Tensor Core instructions.
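///
/// At the vector level this amounts to (a sketch; shapes, indexing maps and
/// iterator types elided for brevity):
/// ```mlir
///   %lhs_ext = arith.extf %lhs : vector<...xf16> to vector<...xf32>
///   %rhs_ext = arith.extf %rhs : vector<...xf16> to vector<...xf32>
///   %res = vector.contract ... %lhs_ext, %rhs_ext, %acc ...
///  ==>
///   %res = vector.contract ... %lhs, %rhs, %acc ...
/// ```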
1562 template <typename ExtOp>
1563 struct FoldArithExtIntoContractionOp
1564     : public OpRewritePattern<vector::ContractionOp> {
1565   using OpRewritePattern::OpRewritePattern;
1566 
1567   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1568                                 PatternRewriter &rewriter) const override {
1569 
1570     auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1571     auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
1572 
1573     if (!lhsDefOp || !rhsDefOp) {
1574       return rewriter.notifyMatchFailure(contractOp,
1575                                          "no defining op on contract operands");
1576     }
1577 
1578     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1579         contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1580         contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1581         contractOp.getIteratorTypesAttr());
1582 
1583     return success();
1584   }
1585 };
1586 
1587 /// Pattern to fold chained reduction to a series of vector additions and a
1588 /// final reduction. This form should require fewer subgroup operations.
1589 ///
1590 /// ```mlir
1591 /// %a = vector.reduction <add> %x, %acc
1592 /// %b = vector.reduction <add> %y, %a
1593 ///  ==>
1594 /// %a = arith.addf %x, %y
1595 /// %b = vector.reduction <add> %a, %acc
1596 /// ```
1597 struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1598   using OpRewritePattern::OpRewritePattern;
1599 
1600   LogicalResult matchAndRewrite(vector::ReductionOp op,
1601                                 PatternRewriter &rewriter) const override {
1602     // TODO: Handle other combining kinds.
1603     if (op.getKind() != vector::CombiningKind::ADD)
1604       return failure();
1605 
1606     // Accumulator is optional.
1607     Value acc = op.getAcc();
1608     if (!acc)
1609       return failure();
1610 
1611     if (!acc.getType().isIntOrFloat())
1612       return failure();
1613 
1614     auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1615     if (!parentReduction)
1616       return failure();
1617 
1618     Location loc = op.getLoc();
1619     Value vAdd;
1620     if (isa<IntegerType>(acc.getType())) {
1621       vAdd = rewriter.createOrFold<arith::AddIOp>(
1622           loc, parentReduction.getVector(), op.getVector());
1623     } else {
1624       vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1625                                             op.getVector());
1626     }
1627     rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1628                                                      parentReduction.getAcc());
1629     return success();
1630   }
1631 };
1632 
1633 // Helper function dropping unit, non-scalable dimensions from a VectorType,
1634 // keeping at least one dimension to avoid generating 0-D vectors. Scalable unit
1635 // dimensions are not dropped. Folding such dimensions would require "shifting"
1636 // the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1637 // vector<[4]xf32>). This could be implemented in the future.
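//
// For example (illustrative only):
//   vector<1x4x1x[4]xf32> -> vector<4x[4]xf32>   (fixed-size unit dims dropped)
//   vector<[1]x1xf32>     -> vector<[1]xf32>     (scalable unit dim kept)
//   vector<1x1xf32>       -> vector<1xf32>       (at least one dim kept)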
1638 static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1639   auto inVecShape = inVecTy.getShape();
1640   SmallVector<int64_t> newShape;
1641   SmallVector<bool> newScalableDims;
1642   for (auto [dim, isScalable] :
1643        llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1644     if (dim == 1 && !isScalable)
1645       continue;
1646 
1647     newShape.push_back(dim);
1648     newScalableDims.push_back(isScalable);
1649   }
1650   // All dims have been dropped, return vector<1xeType>.
1651   if (newShape.empty()) {
1652     newShape.push_back(1);
1653     newScalableDims.push_back(false);
1654   }
1655 
1656   return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1657 }
1658 
1659 /// For vectors with at least one unit dim, replaces:
1660 ///   elementwise(a, b)
1661 /// with:
1662 ///   sc_a = shape_cast(a)
1663 ///   sc_b = shape_cast(b)
1664 ///   res = elementwise(sc_a, sc_b)
1665 ///   return shape_cast(res)
1666 /// The newly inserted shape_cast Ops fold (before elementwise Op) and then
1667 /// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
1668 /// required to be rank > 1.
1669 ///
1670 /// Ex:
1671 ///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1672 ///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1673 ///
1674 /// gets converted to:
1675 ///
1676 ///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1677 ///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1678 ///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1679 ///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1680 ///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1681 ///
1682 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1683 /// `%cast`.
1684 struct DropUnitDimFromElementwiseOps final
1685     : public OpTraitRewritePattern<OpTrait::Elementwise> {
1686   using OpTraitRewritePattern::OpTraitRewritePattern;
1687   LogicalResult matchAndRewrite(Operation *op,
1688                                 PatternRewriter &rewriter) const override {
1689     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1690       return failure();
1691 
1692     auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1693     if (!resultVectorType)
1694       return failure();
1695 
1696     // Check the operand pre-conditions. For `Elementwise` ops all operands are
1697     // guaranteed to have identical shapes (with some exceptions such as
1698     // `arith.select`) and it suffices to only check one of them.
1699     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1700     if (!sourceVectorType)
1701       return failure();
1702     if (sourceVectorType.getRank() < 2)
1703       return failure();
1704 
1705     SmallVector<Value> newOperands;
1706     auto loc = op->getLoc();
1707     for (auto operand : op->getOperands()) {
1708       auto opVectorType = cast<VectorType>(operand.getType());
1709       auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1710       if (newVType == opVectorType)
1711         return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1712 
1713       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
1714       newOperands.push_back(opSC);
1715     }
1716 
1717     VectorType newResultVectorType =
1718         dropNonScalableUnitDimFromType(resultVectorType);
1719     // Create an updated elementwise Op without unit dim.
1720     Operation *elementwiseOp =
1721         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1722                         newResultVectorType, op->getAttrs());
1723 
1724     // Restore the unit dim by applying vector.shape_cast to the result.
1725     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
1726                                              elementwiseOp->getResult(0));
1727 
1728     return success();
1729   }
1730 };
1731 
1732 /// A pattern to drop unit dims from vector.transpose.
1733 ///
1734 /// Example:
1735 ///
1736 ///  BEFORE:
1737 ///  ```mlir
1738 ///  %transpose = vector.transpose %vector, [3, 0, 1, 2]
1739 ///    : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
1740 ///  ```
1741 ///
1742 ///  AFTER:
1743 ///  ```mlir
1744 ///  %dropDims = vector.shape_cast %vector
1745 ///    : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
1746 ///  %transpose = vector.transpose %dropDims, [1, 0]
1747 ///    : vector<4x[4]xf32> to vector<[4]x4xf32>
1748 ///  %restoreDims = vector.shape_cast %transpose
1749 ///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1750 ///  ```
1751 struct DropUnitDimsFromTransposeOp final
1752     : OpRewritePattern<vector::TransposeOp> {
1753   using OpRewritePattern::OpRewritePattern;
1754 
1755   LogicalResult matchAndRewrite(vector::TransposeOp op,
1756                                 PatternRewriter &rewriter) const override {
1757     VectorType sourceType = op.getSourceVectorType();
1758     VectorType sourceTypeWithoutUnitDims =
1759         dropNonScalableUnitDimFromType(sourceType);
1760 
1761     if (sourceType == sourceTypeWithoutUnitDims)
1762       return failure();
1763 
1764     // Construct a map from dimIdx -> number of dims dropped before dimIdx.
1765     auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
1766     SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
1767     int64_t droppedDims = 0;
1768     for (auto [i, dim] : llvm::enumerate(sourceDims)) {
1769       droppedDimsBefore[i] = droppedDims;
1770       if (dim == std::make_tuple(1, false))
1771         ++droppedDims;
1772     }
1773 
1774     // Drop unit dims from transpose permutation.
1775     ArrayRef<int64_t> perm = op.getPermutation();
1776     SmallVector<int64_t> newPerm;
1777     for (int64_t idx : perm) {
1778       if (sourceDims[idx] == std::make_tuple(1, false))
1779         continue;
1780       newPerm.push_back(idx - droppedDimsBefore[idx]);
1781     }
1782 
1783     // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
1784     // when all the dimensions are unit dimensions. In this case, `newPerm`
1785     // should be [0].
1786     if (newPerm.empty()) {
1787       newPerm.push_back(0);
1788     }
1789 
1790     Location loc = op.getLoc();
1791     // Drop the unit dims via shape_cast.
1792     auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
1793         loc, sourceTypeWithoutUnitDims, op.getVector());
1794     // Create the new transpose.
1795     auto transposeWithoutUnitDims =
1796         rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
1797     // Restore the unit dims via shape cast.
1798     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
1799         op, op.getResultVectorType(), transposeWithoutUnitDims);
1800 
1801     return success();
1802   }
1803 };
1804 
1805 /// A pattern to drop unit dims from the iter_args of an scf.for.
1806 ///
1807 /// Example:
1808 ///
1809 ///  BEFORE:
1810 ///  ```mlir
1811 ///  %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
1812 ///    ...
1813 ///    scf.yield %
1814 ///  }
1815 ///  ```
1816 ///
1817 ///  AFTER:
1818 ///  ```mlir
1819 ///  %drop = vector.shape_cast %init
1820 ///    : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
1821 ///  %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
1822 ///    %new_iter = vector.shape_cast %iter
1823 ///      : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1824 ///    ...
1825 ///  }
1826 ///  %res = vector.shape_cast %new_loop
1827 ///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1828 ///  ```
1829 struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
1830   using OpRewritePattern::OpRewritePattern;
1831 
1832   LogicalResult matchAndRewrite(scf::ForOp forOp,
1833                                 PatternRewriter &rewriter) const override {
1834     /// Find the first iter_arg with droppable unit dims. Further applications
1835     /// of this pattern will apply to later arguments.
1836     for (OpOperand &operand : forOp.getInitArgsMutable()) {
1837       auto vectorType = dyn_cast<VectorType>(operand.get().getType());
1838       if (!vectorType)
1839         continue;
1840 
1841       VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
1842       if (vectorType == newVectorType)
1843         continue;
1844 
1845       // Create a new ForOp with that iter operand replaced.
1846       auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
1847         return b.create<vector::ShapeCastOp>(loc, type, source);
1848       };
1849 
1850       Value replacement =
1851           castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
1852       rewriter.replaceOp(forOp,
1853                          replaceAndCastForOpIterArg(rewriter, forOp, operand,
1854                                                     replacement, castFn));
1855       return success();
1856     }
1857     return failure();
1858   }
1859 };
1860 
1861 /// Pattern to eliminate redundant zero-constants added to reduction operands.
1862 /// It's enough for there to be one initial zero value, so we can eliminate the
1863 /// extra ones that feed into `vector.reduction <add>`. These get created by the
1864 /// `ChainedReduction` pattern.
1865 ///
1866 /// ```mlir
1867 /// %a = arith.addf %x, %zero
1868 /// %b = arith.addf %a, %y
1869 /// %c = vector.reduction <add> %b, %acc
1870 ///  ==>
1871 /// %b = arith.addf %x, %y
1872 /// %c = vector.reduction <add> %b, %acc
1873 /// ```
1874 struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
1875   using OpRewritePattern::OpRewritePattern;
1876 
1877   LogicalResult matchAndRewrite(vector::ReductionOp op,
1878                                 PatternRewriter &rewriter) const override {
1879     // TODO: Handle other reduction kinds and their identity values.
1880     if (op.getKind() != vector::CombiningKind::ADD)
1881       return failure();
1882 
1883     Type elemType = op.getSourceVectorType().getElementType();
1884     // The integer case should be handled by `arith.addi` folders, only check
1885     // for floats here.
1886     if (!isa<FloatType>(elemType))
1887       return failure();
1888 
1889     auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
1890     if (!vAdd)
1891       return failure();
1892     auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
1893     if (!addLhs)
1894       return failure();
1895 
1896     if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
1897       return failure();
1898 
1899     auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
1900                                                  vAdd.getRhs());
1901     rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
1902                                                      op.getAcc());
1903     return success();
1904   }
1905 };
1906 
1907 /// Example:
1908 /// ```
1909 /// %a = vector.reduction <add> %x : vector<2xf32> into f32
1910 /// ```
1911 /// is transformed into:
1912 /// ```
1913 /// %y = vector.extract %x[0] : f32 from vector<2xf32>
1914 /// %z = vector.extract %x[1] : f32 from vector<2xf32>
1915 /// %a = arith.addf %y, %z : f32
1916 /// ```
1917 struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
1918   BreakDownVectorReduction(MLIRContext *context,
1919                            unsigned maxNumElementsToExtract,
1920                            PatternBenefit benefit)
1921       : OpRewritePattern(context, benefit),
1922         maxNumElementsToExtract(maxNumElementsToExtract) {}
1923 
1924   LogicalResult matchAndRewrite(vector::ReductionOp op,
1925                                 PatternRewriter &rewriter) const override {
1926     VectorType type = op.getSourceVectorType();
1927     if (type.isScalable() || op.isMasked())
1928       return failure();
1929     assert(type.getRank() == 1 && "Expected a 1-d vector");
1930 
1931     int64_t numElems = type.getNumElements();
1932     if (numElems > maxNumElementsToExtract) {
1933       return rewriter.notifyMatchFailure(
1934           op, llvm::formatv("has too many vector elements ({0}) to break down "
1935                             "(max allowed: {1})",
1936                             numElems, maxNumElementsToExtract));
1937     }
1938 
1939     Location loc = op.getLoc();
1940     SmallVector<Value> extracted(numElems, nullptr);
1941     for (auto [idx, extractedElem] : llvm::enumerate(extracted))
1942       extractedElem = rewriter.create<vector::ExtractOp>(
1943           loc, op.getVector(), static_cast<int64_t>(idx));
1944 
1945     Value res = extracted.front();
1946     for (auto extractedElem : llvm::drop_begin(extracted))
1947       res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
1948                                        extractedElem, op.getFastmathAttr());
1949     if (Value acc = op.getAcc())
1950       res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
1951                                        op.getFastmathAttr());
1952 
1953     rewriter.replaceOp(op, res);
1954     return success();
1955   }
1956 
1957 private:
1958   unsigned maxNumElementsToExtract = 0;
1959 };
1960 
1961 /// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
1962 /// B)`.
1963 /// Example:
1964 ///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
1965 ///  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to vector<4x4xi32>
1966 ///  %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
1967 ///  %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
1968 ///
1969 /// Becomes:
1970 ///
1971 ///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
1972 ///
1973 /// Supports only 1D-to-2D broadcasts. The following cases are not supported.
1974 /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
1975 /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
1976 /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
1977 template <typename MulOpType>
1978 struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
1979   using OpRewritePattern<MulOpType>::OpRewritePattern;
1980   // Returns whether a vector.broadcast matches the requirements for an
1981   // outerproduct pattern, i.e. a 1D-to-2D broadcast without broadcasted unit dims.
1982   bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
1983     // Fail if it is not a 1D-to-2D broadcast, to avoid generating
1984     // shape_casts/broadcasts which do not belong in this pattern.
1985     if (!broadcastOp.computeBroadcastedUnitDims().empty())
1986       return false;
1987     // Avoid broadcast like f32 or vector<f32> -> ResType
1988     auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1989     return srcType && srcType.getRank() != 2;
1990   }
1991 
1992   LogicalResult matchAndRewrite(MulOpType mulOp,
1993                                 PatternRewriter &rewriter) const override {
1994     auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
1995     if (!resType)
1996       return failure();
1997     if (resType.getRank() != 2)
1998       return failure();
1999     /// If operandA can be written as tr(broadcast(A)) and operandB as
2000     /// broadcast(B) where broadcasts are 1D-to-2D, create and return
2001     /// vector.outerproduct(A, B). Returns failure() otherwise.
2002     auto matchOuterProduct =
2003         [&](Value operandA,
2004             Value operandB) -> FailureOr<vector::OuterProductOp> {
2005       auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2006       if (!transposedLhs)
2007         return failure();
2008       // Fail unless this is a true 2-D matrix transpose.
2009       ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2010       if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2011         return failure();
2012 
2013       auto broadcastedLhs =
2014           transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2015       if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2016         return failure();
2017 
2018       auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2019       if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2020         return failure();
2021 
2022       return rewriter.create<vector::OuterProductOp>(
2023           mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2024           broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2025     };
2026 
2027     Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2028     auto maybeOuterP = matchOuterProduct(lhs, rhs);
2029     // Handle commutativity: the transposed op is the outerproduct LHS.
2030     if (failed(maybeOuterP))
2031       maybeOuterP = matchOuterProduct(rhs, lhs);
2032     if (failed(maybeOuterP))
2033       return failure();
2034     rewriter.replaceOp(mulOp, maybeOuterP->getResult());
2035     return success();
2036   }
2037 };
2038 
2039 } // namespace
2040 
2041 void mlir::vector::populateFoldArithExtensionPatterns(
2042     RewritePatternSet &patterns) {
2043   patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
2044                FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
2045       patterns.getContext());
2046 }
2047 
2048 void mlir::vector::populateVectorMaskMaterializationPatterns(
2049     RewritePatternSet &patterns, bool force32BitVectorIndices,
2050     PatternBenefit benefit) {
2051   patterns.add<VectorCreateMaskOpConversion,
2052                MaterializeTransferMask<vector::TransferReadOp>,
2053                MaterializeTransferMask<vector::TransferWriteOp>>(
2054       patterns.getContext(), force32BitVectorIndices, benefit);
2055   patterns.add<FoldI1Select>(patterns.getContext(), benefit);
2056 }
2057 
2058 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
2059                                                     PatternBenefit benefit) {
2060   patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
2061 }
2062 
2063 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2064     RewritePatternSet &patterns, PatternBenefit benefit) {
2065   // TODO: Consider either:
2066   //  * including DropInnerMostUnitDimsTransferRead and
2067   //    DropInnerMostUnitDimsTransferWrite, or
2068   //  * better naming to distinguish this and
2069   //    populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
2070   patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2071                DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
2072       patterns.getContext(), benefit);
2073 }
2074 
2075 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2076     RewritePatternSet &patterns, PatternBenefit benefit) {
2077   patterns.add<BubbleDownVectorBitCastForExtract,
2078                BubbleDownBitCastForStridedSliceExtract,
2079                BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2080       patterns.getContext(), benefit);
2081 }
2082 
2083 void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2084     RewritePatternSet &patterns,
2085     std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2086   patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2087                                        std::move(controlFn), benefit);
2088 }
2089 
2090 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
2091     RewritePatternSet &patterns,
2092     std::function<LogicalResult(vector::ContractionOp)> constraint,
2093     PatternBenefit benefit) {
2094   patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
2095                                                 std::move(constraint));
2096 }
2097 
2098 void mlir::vector::populateVectorReductionToContractPatterns(
2099     RewritePatternSet &patterns, PatternBenefit benefit) {
2100   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2101                CombineContractABTranspose, CombineContractResultTranspose>(
2102       patterns.getContext(), benefit);
2103 }
2104 
2105 void mlir::vector::
2106     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2107         RewritePatternSet &patterns, PatternBenefit benefit) {
2108   patterns.add<DropInnerMostUnitDimsTransferRead,
2109                DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
2110                                                    benefit);
2111 }
2112 
2113 void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
2114                                                  PatternBenefit benefit) {
2115   patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2116                ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
2117                                                  benefit);
2118 }
2119 
2120 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2121     RewritePatternSet &patterns, PatternBenefit benefit) {
2122   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
2123   patterns.add<ReduceRedundantZero>(patterns.getContext(),
2124                                     PatternBenefit(benefit.getBenefit() + 1));
2125 }
2126 
2127 void mlir::vector::populateBreakDownVectorReductionPatterns(
2128     RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
2129     PatternBenefit benefit) {
2130   patterns.add<BreakDownVectorReduction>(patterns.getContext(),
2131                                          maxNumElementsToExtract, benefit);
2132 }
2133 
2134 void mlir::vector::populateElementwiseToVectorOpsPatterns(
2135     RewritePatternSet &patterns) {
2136   patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2137                FoldArithToVectorOuterProduct<arith::MulIOp>>(
2138       patterns.getContext());
2139 }
2140 
2141 //===----------------------------------------------------------------------===//
2142 // TableGen'd enum attribute definitions
2143 //===----------------------------------------------------------------------===//
2144 
2145 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
2146