1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <cassert>
16 #include <cstdint>
17 #include <functional>
18 #include <optional>
19 #include <type_traits>
20 
21 #include "mlir/Dialect/Affine/IR/AffineOps.h"
22 #include "mlir/Dialect/Arith/IR/Arith.h"
23 #include "mlir/Dialect/Arith/Utils/Utils.h"
24 #include "mlir/Dialect/Linalg/IR/Linalg.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Dialect/Utils/IndexingUtils.h"
29 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
30 #include "mlir/Dialect/Vector/IR/VectorOps.h"
31 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
32 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
33 #include "mlir/IR/BuiltinAttributeInterfaces.h"
34 #include "mlir/IR/BuiltinTypes.h"
35 #include "mlir/IR/ImplicitLocOpBuilder.h"
36 #include "mlir/IR/Location.h"
37 #include "mlir/IR/Matchers.h"
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Interfaces/VectorInterfaces.h"
41 #include "mlir/Support/LogicalResult.h"
42 
43 #include "llvm/ADT/DenseSet.h"
44 #include "llvm/ADT/MapVector.h"
45 #include "llvm/ADT/STLExtras.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/FormatVariadic.h"
49 #include "llvm/Support/raw_ostream.h"
50 
51 #define DEBUG_TYPE "vector-to-vector"
52 
53 using namespace mlir;
54 using namespace mlir::vector;
55 
56 template <typename IntType>
57 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
58   return llvm::to_vector<4>(llvm::map_range(
59       arrayAttr.getAsRange<IntegerAttr>(),
60       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
61 }
62 
63 // Helper returning the position of dim `index` among the results of `map`.
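// E.g., for the map (d0, d1, d2) -> (d2, d0), getResultIndex(map, 0) returns 1
// and getResultIndex(map, 1) returns std::nullopt.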
64 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
65   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
66     int64_t idx = map.getDimPosition(i);
67     if (idx == index)
68       return i;
69   }
70   return std::nullopt;
71 }
72 
73 namespace {
74 
75 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
76 //
77 // Example:
78 //
79 //  The following MLIR with cancelling ShapeCastOps:
80 //
81 //   %0 = source : vector<5x4x2xf32>
82 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
83 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
84 //   %3 = user %2 : vector<5x4x2xf32>
85 //
86 //  Should canonicalize to the following:
87 //
88 //   %0 = source : vector<5x4x2xf32>
89 //   %1 = user %0 : vector<5x4x2xf32>
90 //
91 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
92   using OpRewritePattern::OpRewritePattern;
93 
94   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
95                                 PatternRewriter &rewriter) const override {
96     // Check if 'shapeCastOp' has vector source/result type.
97     auto sourceVectorType =
98         dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
99     auto resultVectorType =
100         dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
101     if (!sourceVectorType || !resultVectorType)
102       return failure();
103 
104     // Check if shape cast op source operand is also a shape cast op.
105     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
106         shapeCastOp.getSource().getDefiningOp());
107     if (!sourceShapeCastOp)
108       return failure();
109     auto operandSourceVectorType =
110         cast<VectorType>(sourceShapeCastOp.getSource().getType());
111     auto operandResultVectorType = sourceShapeCastOp.getType();
112 
113     // Check if shape cast operations invert each other.
114     if (operandSourceVectorType != resultVectorType ||
115         operandResultVectorType != sourceVectorType)
116       return failure();
117 
118     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
119     return success();
120   }
121 };
122 
123 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
124 /// Ex:
125 /// ```
126 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
127 ///   %1 = vector.multi_reduction add, %0 [1]
128 ///     : vector<8x32x16xf32> to vector<8x16xf32>
129 /// ```
130 /// Gets converted to:
131 /// ```
132 ///   %1 = vector.contract {indexing_maps = [
133 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
134 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
135 ///         affine_map<(d0, d1, d2) -> (d0, d2)>],
136 ///    iterator_types = ["parallel", "reduction", "parallel"],
137 ///    kind = add} %arg0, %arg1, %cst_f0
138 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
139 ///  ```
140 struct MultiReduceToContract
141     : public OpRewritePattern<vector::MultiDimReductionOp> {
142   using OpRewritePattern::OpRewritePattern;
143 
144   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
145                                 PatternRewriter &rewriter) const override {
146     if (reduceOp.getKind() != vector::CombiningKind::ADD)
147       return failure();
148     Operation *mulOp = reduceOp.getSource().getDefiningOp();
149     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
150       return failure();
151     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
152     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
153     SmallVector<AffineExpr> exprs;
154     SmallVector<vector::IteratorType> iteratorTypes;
155     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
156       if (!isReduceDim.value()) {
157         iteratorTypes.push_back(vector::IteratorType::parallel);
158         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
159       } else {
160         iteratorTypes.push_back(vector::IteratorType::reduction);
161       }
162     }
163     auto dstMap =
164         AffineMap::get(/*dimCount=*/reductionMask.size(),
165                        /*symbolCount=*/0, exprs, reduceOp.getContext());
166     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
167         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
168         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
169         rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
170             iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
171               return IteratorTypeAttr::get(rewriter.getContext(), t);
172             }))));
173     return success();
174   }
175 };
176 
177 /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
178 /// Ex:
179 /// ```
180 ///   %0 = vector.transpose %arg0, [2, 0, 1]
181 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
182 ///   %1 = vector.contract {indexing_maps = [
183 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
184 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
185 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
186 ///    iterator_types = ["parallel", "parallel", "reduction"],
187 ///    kind = add} %0, %arg1, %cst_f0
188 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
189 /// ```
190 /// Gets converted to:
191 /// ```
192 ///   %1 = vector.contract {indexing_maps = [
193 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
194 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
195 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
196 ///    iterator_types = ["parallel", "parallel", "reduction"],
197 ///    kind = add} %arg0, %arg1, %cst_f0
198 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
199 ///  ```
200 struct CombineContractABTranspose final
201     : public OpRewritePattern<vector::ContractionOp> {
202   using OpRewritePattern::OpRewritePattern;
203 
204   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
205                                 PatternRewriter &rewriter) const override {
206     SmallVector<AffineMap> maps =
207         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
208     Value lhs = contractOp.getLhs();
209     Value rhs = contractOp.getRhs();
210     size_t index = 0;
211     bool changed = false;
212     for (Value *operand : {&lhs, &rhs}) {
213       AffineMap &map = maps[index++];
214       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
215       if (!transposeOp)
216         continue;
217       AffineMap permutationMap = AffineMap::getPermutationMap(
218           transposeOp.getPermutation(), contractOp.getContext());
219       map = inversePermutation(permutationMap).compose(map);
220       *operand = transposeOp.getVector();
221       changed = true;
222     }
223     if (!changed)
224       return failure();
225     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
226         contractOp, lhs, rhs, contractOp.getAcc(),
227         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
228     return success();
229   }
230 };
231 
232 /// Merges accumulator and result transposes into contract.
233 ///
234 /// For example:
235 /// ```mlir
236 /// %accT = vector.transpose %acc, [0, 2, 1]
237 ///   : vector<2x8x4xf32> to vector<2x4x8xf32>
238 /// %contract = vector.contract {
239 ///   indexing_maps = [
240 ///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
241 ///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
242 ///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
243 ///   ],
244 ///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
245 ///   kind = #vector.kind<add>
246 /// } %lhs, %rhs, %accT
247 ///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
248 /// %0 = vector.transpose %contract, [0, 2, 1]
249 ///   : vector<2x4x8xf32> to vector<2x8x4xf32>
250 /// ```
251 /// Becomes:
252 /// ```mlir
253 /// %0 = vector.contract {
254 ///   indexing_maps = [
255 ///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
256 ///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
257 ///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
258 ///   ],
259 ///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
260 ///   kind = #vector.kind<add>
261 /// } %lhs, %rhs, %acc
262 ///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
263 /// ```
264 struct CombineContractResultTranspose final
265     : public OpRewritePattern<vector::TransposeOp> {
266   using OpRewritePattern::OpRewritePattern;
267 
268   LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
269                                 PatternRewriter &rewriter) const override {
270     auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
271     if (!contractOp || !contractOp->hasOneUse())
272       return failure();
273 
274     auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
275     if (!accTOp)
276       return failure();
277 
278     MLIRContext *context = contractOp.getContext();
279     auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
280     AffineMap contractMap = maps.back();
281 
282     // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
283     // To index into A in the contract, we need inverse(f)(g(C)) -> A.
284     auto accTMap =
285         AffineMap::getPermutationMap(accTOp.getPermutation(), context);
286 
287     // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
288     // To index into E in contract, we need h(g(C)) -> E.
289     auto resTMap =
290         AffineMap::getPermutationMap(resTOp.getPermutation(), context);
291     auto combinedResMap = resTMap.compose(contractMap);
292 
293     // The accumulator and the result share the same indexing map, so the merge
294     // is only valid if combinedResMap equals
295     // inversePermutation(accTMap).compose(contractMap), i.e. resTMap == inversePermutation(accTMap).
296     if (inversePermutation(accTMap) != resTMap)
297       return failure();
298     maps.back() = combinedResMap;
299 
300     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
301         resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
302         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
303     return success();
304   }
305 };
306 
307 /// Merge BroadcastOp into ContractionOp user.
308 /// Ex:
309 /// ```
310 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
311 ///   %1 = vector.contract {indexing_maps = [
312 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
313 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
314 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
315 ///    iterator_types = ["parallel", "parallel", "reduction"],
316 ///    kind = add} %0, %arg1, %cst_f0
317 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
318 /// ```
319 /// Gets converted to:
320 /// ```
321 ///   %1 = vector.contract {indexing_maps = [
322 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
323 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
324 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
325 ///    iterator_types = ["parallel", "parallel", "reduction"],
326 ///    kind = add} %arg0, %arg1, %cst_f0
327 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
328 ///  ```
329 struct CombineContractBroadcast
330     : public OpRewritePattern<vector::ContractionOp> {
331   using OpRewritePattern::OpRewritePattern;
332 
333   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
334                                 PatternRewriter &rewriter) const override {
335     SmallVector<AffineMap> maps =
336         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
337     Value lhs = contractOp.getLhs();
338     Value rhs = contractOp.getRhs();
339     size_t index = 0;
340     bool changed = false;
341     for (Value *operand : {&lhs, &rhs}) {
342       AffineMap &map = maps[index++];
343       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
344       if (!broadcast)
345         continue;
346       // vector.contract can only take vectors as operands.
347       auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
348       if (!srcType ||
349           srcType.getRank() == broadcast.getResultVectorType().getRank())
350         continue;
351       int64_t rankDiff =
352           broadcast.getResultVectorType().getRank() - srcType.getRank();
353       bool innerDimBroadcast = false;
354       SmallVector<AffineExpr> originalDims;
355       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
356         if (dim.value() != broadcast.getResultVectorType().getDimSize(
357                                rankDiff + dim.index())) {
358           innerDimBroadcast = true;
359           break;
360         }
361         originalDims.push_back(
362             rewriter.getAffineDimExpr(dim.index() + rankDiff));
363       }
364       // Contract doesn't support inner dimension broadcast. Once this is
365       // relaxed we can remove this case.
366       if (innerDimBroadcast)
367         continue;
368 
369       // It would be incorrect to fold a broadcast onto a reduction dimension
370       // of non-unit size.
371       bool nonUnitDimReductionBroadcast = false;
372       for (int64_t i = 0; i < rankDiff; ++i) {
373         if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
374             isReductionIterator(contractOp.getIteratorTypes()
375                                     .getValue()[map.getDimPosition(i)])) {
376           nonUnitDimReductionBroadcast = true;
377           break;
378         }
379       }
380       if (nonUnitDimReductionBroadcast)
381         continue;
382 
383       AffineMap broadcastMap =
384           AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
385                          originalDims, contractOp.getContext());
386       map = broadcastMap.compose(map);
387       *operand = broadcast.getSource();
388       changed = true;
389     }
390 
391     if (!changed)
392       return failure();
393 
394     // Determine which dims are unused, now that the maps have been composed
395     // with the broadcast maps.
396     llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
397     // Compress unused dims.
398     for (auto &m : maps)
399       m = compressDims(m, unusedDimsBitVector);
400     // Compute the combined iterators.
401     SmallVector<Attribute> iterators;
402     for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
403       if (!unusedDimsBitVector.test(i))
404         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
405     }
406     // Check that compressing unused dims isn't removing all reduction dimension
407     // pairs. For example, if the vector.contract had only one reduction
408     // iterator and that was a unit-dimension created by a broadcast,
409     // then we should bail here, otherwise we would create a contract without
410     // a reduction dimension pair.
411     bool hasReductionIteratorApplyingOnBothSides = false;
412     for (unsigned i = 0; i < iterators.size(); ++i) {
413       if (!isReductionIterator(iterators[i]))
414         continue;
415       if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
416         hasReductionIteratorApplyingOnBothSides = true;
417         break;
418       }
419     }
420     if (!hasReductionIteratorApplyingOnBothSides)
421       return failure();
422 
423     // If the compressed maps have a dimension that is not used by either LHS or
424     // RHS then the ContractionOp verifier would fail.
425     if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
426       return failure();
427     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
428         contractOp, lhs, rhs, contractOp.getAcc(),
429         rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
430     return success();
431   }
432 };
433 
434 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
435 /// contraction ops closer, so that the CombineContractBroadcast pattern can
436 /// kick in when cast ops sit between these operations.
437 /// Ex:
438 /// ```
439 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
440 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
441 /// ```
442 /// Gets converted to:
443 /// ```
444 ///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
445 ///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
446 /// ```
447 struct ReorderCastOpsOnBroadcast
448     : public OpInterfaceRewritePattern<CastOpInterface> {
449   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
450 
451   LogicalResult matchAndRewrite(CastOpInterface op,
452                                 PatternRewriter &rewriter) const override {
453     if (op->getNumOperands() != 1)
454       return failure();
455     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
456     if (!bcastOp)
457       return failure();
458 
459     Type castResTy = getElementTypeOrSelf(op->getResult(0));
460     if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
461       castResTy = vecTy.clone(castResTy);
462     auto *castOp =
463         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
464                         bcastOp.getSource(), castResTy, op->getAttrs());
465     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
466         op, op->getResult(0).getType(), castOp->getResult(0));
467     return success();
468   }
469 };
470 
471 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
472 /// transpose ops and contraction ops closer, so that the CombineContractABTranspose
473 /// pattern can kick in when elementwise ops sit between these
474 /// operations. Ex:
475 /// ```
476 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
477 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
478 /// %r = arith.addf %at, %bt : vector<2x4xf32>
479 /// ```
480 /// Gets converted to:
481 /// ```
482 /// %0 = arith.addf %a, %b : vector<4x2xf32>
483 /// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
484 /// ```
485 struct ReorderElementwiseOpsOnTranspose final
486     : public OpTraitRewritePattern<OpTrait::Elementwise> {
487   using OpTraitRewritePattern::OpTraitRewritePattern;
488   LogicalResult matchAndRewrite(Operation *op,
489                                 PatternRewriter &rewriter) const override {
490     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
491       return failure();
492 
493     // Make sure all operands are transpose/constant ops and collect their
494     // transposition maps.
495     SmallVector<ArrayRef<int64_t>> transposeMaps;
496     transposeMaps.reserve(op->getNumOperands());
497     // Record the initial type before transposition. We'll use its shape later.
498     // Any type will do here as we will check all transpose maps are the same.
499     VectorType srcType;
500     for (Value operand : op->getOperands()) {
501       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
502       if (transposeOp) {
503         transposeMaps.push_back(transposeOp.getPermutation());
504         srcType = transposeOp.getSourceVectorType();
505       } else if (!matchPattern(operand, m_Constant())) {
506         return failure();
507       }
508     }
509     if (transposeMaps.empty())
510       return failure();
511     // This is an elementwise op, so all transposed operands should have the
512     // same type. We additionally need to check that all transposes use the
513     // same map.
514     if (!llvm::all_equal(transposeMaps))
515       return rewriter.notifyMatchFailure(op, "different transpose map");
516 
517     SmallVector<Value> srcValues;
518     srcValues.reserve(op->getNumOperands());
519 
520     // If there are constant operands, we need to insert inverse transposes for
521     // them. Calculate the inverse order first.
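    // E.g., the permutation [2, 0, 1] has inverse order [1, 2, 0].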
522     auto order = transposeMaps.front();
523     SmallVector<int64_t> invOrder(order.size());
524     for (int i = 0, e = order.size(); i < e; ++i)
525       invOrder[order[i]] = i;
526 
527     for (Value operand : op->getOperands()) {
528       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
529       if (transposeOp) {
530         srcValues.push_back(transposeOp.getVector());
531       } else {
532         // This is a constant. Create a reverse transpose op for it.
533         auto vectorType =
534             srcType.clone(cast<VectorType>(operand.getType()).getElementType());
535         srcValues.push_back(rewriter.create<vector::TransposeOp>(
536             operand.getLoc(), vectorType, operand, invOrder));
537       }
538     }
539 
540     auto vectorType = srcType.clone(
541         cast<VectorType>(op->getResultTypes()[0]).getElementType());
542     Operation *elementwiseOp =
543         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
544                         vectorType, op->getAttrs());
545     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
546         op, op->getResultTypes()[0], elementwiseOp->getResult(0),
547         transposeMaps.front());
548     return success();
549   }
550 };
551 
552 // Returns the values in `arrayAttr` as an integer vector.
553 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
554   return llvm::to_vector<4>(
555       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
556                       [](IntegerAttr attr) { return attr.getInt(); }));
557 }
558 
559 // Shuffles vector.bitcast op after vector.extract op.
560 //
561 // This transforms IR like:
562 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
563 //   %1 = vector.extract %0[3] : f16 from vector<8xf16>
564 // Into:
565 //   %0 = vector.extract %src[1] : f32 from vector<4xf32>
566 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
567 //   %2 = vector.extract %1[1] : f16 from vector<2xf16>
568 struct BubbleDownVectorBitCastForExtract
569     : public OpRewritePattern<vector::ExtractOp> {
570   using OpRewritePattern::OpRewritePattern;
571 
572   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
573                                 PatternRewriter &rewriter) const override {
574     // Only support extracting scalars for now.
575     if (extractOp.getSourceVectorType().getRank() != 1)
576       return failure();
577 
578     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
579     if (!castOp)
580       return failure();
581 
582     VectorType castSrcType = castOp.getSourceVectorType();
583     VectorType castDstType = castOp.getResultVectorType();
584     assert(castSrcType.getRank() == castDstType.getRank());
585 
586     // Fail to match if we only have one element in the cast op source.
587     // This is to avoid an infinite loop, given that this pattern can generate
588     // such cases.
589     if (castSrcType.getNumElements() == 1)
590       return failure();
591 
592     // Only support casting to a larger number of elements for now.
593     // E.g., vector<4xf32> -> vector<8xf16>.
594     if (castSrcType.getNumElements() > castDstType.getNumElements())
595       return failure();
596 
597     unsigned expandRatio =
598         castDstType.getNumElements() / castSrcType.getNumElements();
599 
600     auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
601       assert(values[0].is<Attribute>() && "Unexpected non-constant index");
602       return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
603     };
604 
605     uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
606 
607     // Get the single scalar (as a vector) in the source value that packs the
608     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
609     Location loc = extractOp.getLoc();
610     Value packedValue = rewriter.create<vector::ExtractOp>(
611         loc, castOp.getSource(), index / expandRatio);
612     Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
613     Value zero = rewriter.create<arith::ConstantOp>(
614         loc, packedVecType, rewriter.getZeroAttr(packedVecType));
615     packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
616                                                     /*position=*/0);
617 
618     // Cast it to a vector with the desired scalar's type.
619     // E.g. f32 -> vector<2xf16>
620     VectorType packedType =
621         VectorType::get({expandRatio}, castDstType.getElementType());
622     Value castedValue =
623         rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
624 
625     // Finally extract the desired scalar.
626     rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
627                                                    index % expandRatio);
628     return success();
629   }
630 };
631 
632 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
633 //
634 // This transforms IR like:
635 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
636 //     %0 = vector.extract_strided_slice %cast {
637 //            offsets = [4], sizes = [4], strides = [1]
638 //          } : vector<8xf16> to vector<4xf16>
639 // Into:
640 //   %0 = vector.extract_strided_slice %src {
641 //          offsets = [2], sizes = [2], strides = [1]
642 //        } : vector<4xf32> to vector<2xf32>
643 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
644 struct BubbleDownBitCastForStridedSliceExtract
645     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
646   using OpRewritePattern::OpRewritePattern;
647 
648   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
649                                 PatternRewriter &rewriter) const override {
650     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
651     if (!castOp)
652       return failure();
653 
654     VectorType castSrcType = castOp.getSourceVectorType();
655     VectorType castDstType = castOp.getResultVectorType();
656     assert(castSrcType.getRank() == castDstType.getRank());
657 
658     int64_t castSrcLastDim = castSrcType.getShape().back();
659     int64_t castDstLastDim = castDstType.getShape().back();
660     // Require casting to more elements for now; other cases to be implemented.
661     if (castSrcLastDim > castDstLastDim)
662       return failure();
663 
664     // Only accept all one strides for now.
665     if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
666                      [](const APInt &val) { return !val.isOne(); }))
667       return failure();
668 
669     unsigned rank = extractOp.getSourceVectorType().getRank();
670     assert(castDstLastDim % castSrcLastDim == 0);
671     int64_t expandRatio = castDstLastDim / castSrcLastDim;
672 
673     // If we have fewer offsets than the rank, then implicitly we are
674     // selecting the full range for the last bitcasted dimension; other
675     // dimensions aren't affected. Otherwise, we need to scale down the last
676     // dimension's offset given we are extracting from fewer elements now.
677     ArrayAttr newOffsets = extractOp.getOffsets();
678     if (newOffsets.size() == rank) {
679       SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
680       if (offsets.back() % expandRatio != 0)
681         return failure();
682       offsets.back() = offsets.back() / expandRatio;
683       newOffsets = rewriter.getI64ArrayAttr(offsets);
684     }
685 
686     // Similarly for sizes.
687     ArrayAttr newSizes = extractOp.getSizes();
688     if (newSizes.size() == rank) {
689       SmallVector<int64_t> sizes = getIntValueVector(newSizes);
690       if (sizes.back() % expandRatio != 0)
691         return failure();
692       sizes.back() = sizes.back() / expandRatio;
693       newSizes = rewriter.getI64ArrayAttr(sizes);
694     }
695 
696     SmallVector<int64_t> dims =
697         llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
698     dims.back() = dims.back() / expandRatio;
699     VectorType newExtractType =
700         VectorType::get(dims, castSrcType.getElementType());
701 
702     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
703         extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
704         newSizes, extractOp.getStrides());
705 
706     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
707         extractOp, extractOp.getType(), newExtractOp);
708 
709     return success();
710   }
711 };
712 
713 // Shuffles vector.bitcast op before vector.insert op.
714 //
715 // This transforms IR like:
716 //   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
717 //   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
718 // Into:
719 //   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
720 //   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
721 //   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
722 //
723 struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
724   using OpRewritePattern::OpRewritePattern;
725 
726   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
727                                 PatternRewriter &rewriter) const override {
728     VectorType castSrcType = bitcastOp.getSourceVectorType();
729     VectorType castDstType = bitcastOp.getResultVectorType();
730 
731     // 0-D and scalable vectors are not supported yet.
732     if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
733         castDstType.isScalable())
734       return failure();
735 
736     int64_t castSrcLastDim = castSrcType.getShape().back();
737     int64_t castDstLastDim = castDstType.getShape().back();
738     bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
739     int64_t ratio;
740     if (isNumElemsShrink) {
741       assert(castSrcLastDim % castDstLastDim == 0);
742       ratio = castSrcLastDim / castDstLastDim;
743     } else {
744       assert(castDstLastDim % castSrcLastDim == 0);
745       ratio = castDstLastDim / castSrcLastDim;
746     }
747 
748     auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
749     if (!insertOp)
750       return failure();
751 
752     // Only vector sources are supported for now.
753     auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
754     if (!insertSrcType)
755       return failure();
756 
757     // Bitcast the source.
758     SmallVector<int64_t> srcDims(insertSrcType.getShape());
759     srcDims.back() =
760         isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
761     VectorType newCastSrcType =
762         VectorType::get(srcDims, castDstType.getElementType());
763     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
764         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
765 
766     SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
767     dstDims.back() =
768         isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
769     VectorType newCastDstType =
770         VectorType::get(dstDims, castDstType.getElementType());
771 
772     // Bitcast the destination.
773     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
774         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
775 
776     // Generate new insert.
777     rewriter.replaceOpWithNewOp<vector::InsertOp>(
778         bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
779     return success();
780   }
781 };
782 
783 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
784 //
785 // This transforms IR like:
786 //   %0 = vector.insert_strided_slice %src, %dst {
787 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
788 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
789 // Into:
790 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
791 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
792 //   %2 = vector.insert_strided_slice %0, %1 {
793 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
794 struct BubbleUpBitCastForStridedSliceInsert
795     : public OpRewritePattern<vector::BitCastOp> {
796   using OpRewritePattern::OpRewritePattern;
797 
798   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
799                                 PatternRewriter &rewriter) const override {
800     VectorType castSrcType = bitcastOp.getSourceVectorType();
801     VectorType castDstType = bitcastOp.getResultVectorType();
802     assert(castSrcType.getRank() == castDstType.getRank());
803     // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
804     if (castSrcType.getRank() == 0)
805       return failure();
806 
807     int64_t castSrcLastDim = castSrcType.getShape().back();
808     int64_t castDstLastDim = castDstType.getShape().back();
809     // Require casting to fewer elements for now; other cases to be implemented.
810     if (castSrcLastDim < castDstLastDim)
811       return failure();
812 
813     assert(castSrcLastDim % castDstLastDim == 0);
814     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
815 
816     auto insertOp =
817         bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
818     if (!insertOp)
819       return failure();
820 
821     // Only accept all one strides for now.
822     if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
823                      [](const APInt &val) { return !val.isOne(); }))
824       return failure();
825 
826     unsigned rank = insertOp.getSourceVectorType().getRank();
827     // Require insert op to have the same rank for the source and destination
828     // vector; other cases to be implemented.
829     if (rank != insertOp.getDestVectorType().getRank())
830       return failure();
831 
832     // Require that the shape of the insert op source is castable to dstType.
833     unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
834     unsigned destinationWidth =
835         castDstType.getElementType().getIntOrFloatBitWidth();
836     unsigned numElements = destinationWidth / sourceWidth;
837     if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
838       return failure();
839 
840     ArrayAttr newOffsets = insertOp.getOffsets();
841     assert(newOffsets.size() == rank);
842     SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
843     if (offsets.back() % shrinkRatio != 0)
844       return failure();
845     offsets.back() = offsets.back() / shrinkRatio;
846     newOffsets = rewriter.getI64ArrayAttr(offsets);
847 
848     SmallVector<int64_t> srcDims =
849         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
850     srcDims.back() = srcDims.back() / shrinkRatio;
851     VectorType newCastSrcType =
852         VectorType::get(srcDims, castDstType.getElementType());
853 
854     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
855         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
856 
857     SmallVector<int64_t> dstDims =
858         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
859     dstDims.back() = dstDims.back() / shrinkRatio;
860     VectorType newCastDstType =
861         VectorType::get(dstDims, castDstType.getElementType());
862 
863     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
864         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
865 
866     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
867         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
868         insertOp.getStrides());
869 
870     return success();
871   }
872 };
873 
874 // Breaks down vector.bitcast op
875 //
876 // This transforms IR like:
877 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
878 // Into:
879 //   %cst = vector.splat %c0_f32 : vector<4xf32>
880 //   %1 = vector.extract_strided_slice %0 {
881 //          offsets = [0], sizes = [4], strides = [1]
882 //        } : vector<8xf16> to vector<4xf16>
883 //   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
884 //   %4 = vector.insert_strided_slice %2, %cst {
885 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
886 //   %5 = vector.extract_strided_slice %0 {
887 //          offsets = [4], sizes = [4], strides = [1]
888 //        } : vector<8xf16> to vector<4xf16>
889 //   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
890 //   %7 = vector.insert_strided_slice %6, %4 {
891 //          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
892 struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
893   using OpRewritePattern::OpRewritePattern;
894 
895 public:
896   BreakDownVectorBitCast(MLIRContext *context,
897                          std::function<bool(vector::BitCastOp)> controlFn,
898                          PatternBenefit benefit)
899       : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
900 
901   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
902                                 PatternRewriter &rewriter) const override {
903 
904     if (controlFn && !controlFn(bitcastOp))
905       return failure();
906 
907     VectorType castSrcType = bitcastOp.getSourceVectorType();
908     VectorType castDstType = bitcastOp.getResultVectorType();
909     assert(castSrcType.getRank() == castDstType.getRank());
910 
911     // Only support rank 1 case for now.
912     if (castSrcType.getRank() != 1)
913       return failure();
914 
915     int64_t castSrcLastDim = castSrcType.getShape().back();
916     int64_t castDstLastDim = castDstType.getShape().back();
917     // Require casting to fewer elements for now; other cases to be implemented.
918     if (castSrcLastDim < castDstLastDim)
919       return failure();
920 
921     assert(castSrcLastDim % castDstLastDim == 0);
922     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
923     // Nothing to do if it is already bitcasting to a single element.
924     if (castSrcLastDim == shrinkRatio)
925       return failure();
926 
927     Location loc = bitcastOp.getLoc();
928     Type elemType = castDstType.getElementType();
929     assert(elemType.isSignlessIntOrIndexOrFloat());
930 
931     Value zero = rewriter.create<arith::ConstantOp>(
932         loc, elemType, rewriter.getZeroAttr(elemType));
933     Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
934 
935     SmallVector<int64_t> sliceShape{castDstLastDim};
936     SmallVector<int64_t> strides{1};
937     VectorType newCastDstType =
938         VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
939                         castDstType.getElementType());
940 
941     for (int i = 0, e = shrinkRatio; i < e; ++i) {
942       Value extracted = rewriter.create<ExtractStridedSliceOp>(
943           loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
944           sliceShape, strides);
945       Value bitcast =
946           rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
947       res = rewriter.create<InsertStridedSliceOp>(
948           loc, bitcast, res,
949           ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
950     }
951     rewriter.replaceOp(bitcastOp, res);
952     return success();
953   }
954 
955 private:
956   std::function<bool(BitCastOp)> controlFn;
957 };
958 
959 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
960 /// ```
961 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
962 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
963 /// %r = arith.addi %a, %b : vector<1x4xindex>
964 /// ```
965 /// Gets converted to:
966 /// ```
967 /// %0 = arith.addi %arg1, %arg2 : index
968 /// %r = vector.broadcast %0 : index to vector<1x4xindex>
969 /// ```
970 ///
971 /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
972 /// ops.
973 struct ReorderElementwiseOpsOnBroadcast final
974     : public OpTraitRewritePattern<OpTrait::Elementwise> {
975   using OpTraitRewritePattern::OpTraitRewritePattern;
976   LogicalResult matchAndRewrite(Operation *op,
977                                 PatternRewriter &rewriter) const override {
978     if (op->getNumResults() != 1)
979       return failure();
980     if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
981       return failure();
982     if (!OpTrait::hasElementwiseMappableTraits(op))
983       return failure();
984     if (op->getNumOperands() == 0 ||
985         op->getResults()[0].getType() != op->getOperand(0).getType()) {
986       return failure();
987     }
988     // Avoid operations that only accept vector types, since broadcast
989     // source might be scalar types.
990     if (isa<vector::FMAOp>(op)) {
991       return failure();
992     }
993 
994     // Get the type of the lhs operand
995     auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
996     if (!lhsBcastOrSplat ||
997         !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
998       return failure();
999     auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
1000 
1001     // Make sure that all operands are broadcast from identical types:
1002     //  * scalar (`vector.broadcast` + `vector.splat`), or
1003     //  * vector (`vector.broadcast`).
1004     // Otherwise the re-ordering wouldn't be safe.
1005     if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
1006           auto bcast = val.getDefiningOp<vector::BroadcastOp>();
1007           if (bcast)
1008             return (bcast.getOperand().getType() == lhsBcastOrSplatType);
1009           auto splat = val.getDefiningOp<vector::SplatOp>();
1010           if (splat)
1011             return (splat.getOperand().getType() == lhsBcastOrSplatType);
1012           return false;
1013         })) {
1014       return failure();
1015     }
1016 
1017     // Collect the source values before broadcasting
1018     SmallVector<Value> srcValues;
1019     srcValues.reserve(op->getNumOperands());
1020     for (Value operand : op->getOperands()) {
1021       srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1022     }
1023 
1024     // Create the "elementwise" Op
1025     Operation *elementwiseOp =
1026         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1027                         lhsBcastOrSplatType, op->getAttrs());
1028 
1029     // Replace the original Op with the elementwise Op
1030     auto vectorType = op->getResultTypes()[0];
1031     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1032         op, vectorType, elementwiseOp->getResults());
1033 
1034     return success();
1035   }
1036 };
1037 
1038 // Helper that returns a vector comparison that constructs a mask:
1039 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1040 //
1041 // If `dim == 0` then the result will be a 0-D vector.
1042 //
1043 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1044 //       much more compact, IR for this operation, but LLVM eventually
1045 //       generates more elaborate instructions for this intrinsic since it
1046 //       is very conservative on the boundary conditions.
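//
// For illustration, with `force32BitVectorIndices`, `dim == 4`, and an offset
// present, the generated IR is roughly the following sketch (exact SSA names
// and index casts will differ):
//
//   %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %ov      = vector.splat %o : vector<4xi32>
//   %sum     = arith.addi %ov, %indices : vector<4xi32>
//   %bounds  = vector.splat %b : vector<4xi32>
//   %mask    = arith.cmpi slt, %sum, %bounds : vector<4xi32>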
1047 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1048                                    bool force32BitVectorIndices, int64_t dim,
1049                                    Value b, Value *off = nullptr) {
1050   auto loc = op->getLoc();
1051   // If we can assume all indices fit in 32-bit, we perform the vector
1052   // comparison in 32-bit to get a higher degree of SIMD parallelism.
1053   // Otherwise we perform the vector comparison using 64-bit indices.
1054   Type idxType =
1055       force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1056   DenseIntElementsAttr indicesAttr;
1057   if (dim == 0 && force32BitVectorIndices) {
1058     indicesAttr = DenseIntElementsAttr::get(
1059         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
1060   } else if (dim == 0) {
1061     indicesAttr = DenseIntElementsAttr::get(
1062         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
1063   } else if (force32BitVectorIndices) {
1064     indicesAttr = rewriter.getI32VectorAttr(
1065         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1066   } else {
1067     indicesAttr = rewriter.getI64VectorAttr(
1068         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1069   }
1070   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
1071   // Add in an offset if requested.
1072   if (off) {
1073     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1074     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
1075     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
1076   }
1077   // Construct the vector comparison.
1078   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1079   Value bounds =
1080       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
1081   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
1082                                         bounds);
1083 }
1084 
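/// Materializes the implicit out-of-bounds behaviour of a 1-D transfer op as an
/// explicit mask. A rough sketch of the effect on a transfer_read (shapes and
/// names assumed for illustration):
///
///   %v = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<8xf32>
///
/// becomes, with %d = memref.dim %src, %c0 and %b = %d - %i,
///
///   %mask = vector.create_mask %b : vector<8xi1>
///   %v = vector.transfer_read %src[%i], %pad, %mask {in_bounds = [true]}
///     : memref<?xf32>, vector<8xf32>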
1085 template <typename ConcreteOp>
1086 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1087 public:
1088   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1089                                    PatternBenefit benefit = 1)
1090       : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1091         force32BitVectorIndices(enableIndexOpt) {}
1092 
1093   LogicalResult matchAndRewrite(ConcreteOp xferOp,
1094                                 PatternRewriter &rewriter) const override {
1095     if (!xferOp.hasOutOfBoundsDim())
1096       return failure();
1097 
1098     if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1099       return failure();
1100 
1101     Location loc = xferOp->getLoc();
1102     VectorType vtp = xferOp.getVectorType();
1103 
1104     // Create the in-bounds mask with all elements between [0 .. dim - offset)
1105     // set and [dim - offset .. vector_length) unset.
1106     //
1107     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1108     //       dimensions here.
1109     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1110     Value off = xferOp.getIndices()[lastIndex];
1111     Value dim =
1112         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
1113     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
1114     Value mask = rewriter.create<vector::CreateMaskOp>(
1115         loc,
1116         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1117                         vtp.getScalableDims()),
1118         b);
1119     if (xferOp.getMask()) {
1120       // Intersect the in-bounds with the mask specified as an op parameter.
1121       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
1122     }
1123 
1124     rewriter.modifyOpInPlace(xferOp, [&]() {
1125       xferOp.getMaskMutable().assign(mask);
1126       xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1127     });
1128 
1129     return success();
1130   }
1131 
1132 private:
1133   const bool force32BitVectorIndices;
1134 };
1135 
1136 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
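/// E.g. (1-D sketch): `vector.create_mask %b : vector<4xi1>` becomes a
/// comparison of the constant vector [0, 1, 2, 3] against a splat of %b, as
/// built by buildVectorComparison above.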
1137 class VectorCreateMaskOpConversion
1138     : public OpRewritePattern<vector::CreateMaskOp> {
1139 public:
1140   explicit VectorCreateMaskOpConversion(MLIRContext *context,
1141                                         bool enableIndexOpt,
1142                                         PatternBenefit benefit = 1)
1143       : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1144         force32BitVectorIndices(enableIndexOpt) {}
1145 
1146   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1147                                 PatternRewriter &rewriter) const override {
1148     auto dstType = op.getType();
1149     if (cast<VectorType>(dstType).isScalable())
1150       return failure();
1151     int64_t rank = dstType.getRank();
1152     if (rank > 1)
1153       return failure();
1154     rewriter.replaceOp(
1155         op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1156                                   rank == 0 ? 0 : dstType.getDimSize(0),
1157                                   op.getOperand(0)));
1158     return success();
1159   }
1160 
1161 private:
1162   const bool force32BitVectorIndices;
1163 };
1164 
1165 /// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1166 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1167   auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1168   // TODO: Support non-dense constant.
1169   if (!denseAttr)
1170     return false;
1171 
1172   assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1173   return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1174 }
1175 
1176 /// Folds a select operation between an all-true and all-false vector. For now,
1177 /// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1178 ///
1179 ///   %true = arith.constant dense<true> : vector<1xi1>
1180 ///   %false = arith.constant dense<false> : vector<1xi1>
1181 ///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1182 ///   =>
1183 ///   %result = vector.broadcast %cond : i1 to vector<1xi1>
1184 ///
1185 /// InstCombine seems to handle vectors with multiple elements but not the
1186 /// single element ones.
1187 struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1188   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1189 
1190   LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1191                                 PatternRewriter &rewriter) const override {
1192     auto vecType = dyn_cast<VectorType>(selectOp.getType());
1193     if (!vecType || !vecType.getElementType().isInteger(1))
1194       return failure();
1195 
1196     // Only scalar conditions can be folded.
1197     Value cond = selectOp.getCondition();
1198     if (isa<VectorType>(cond.getType()))
1199       return failure();
1200 
1201     // TODO: Support n-D and scalable vectors.
1202     if (vecType.getRank() != 1 || vecType.isScalable())
1203       return failure();
1204 
1205     // TODO: Support vectors with multiple elements.
1206     if (vecType.getShape()[0] != 1)
1207       return failure();
1208 
1209     auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1210     if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1211       return failure();
1212 
1213     auto falseConst =
1214         selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1215     if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1216       return failure();
1217 
1218     // Replace select with its condition broadcasted to single element vector.
1219     auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1220     auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1221     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1222     return success();
1223   }
1224 };
1225 
1226 /// Returns the number of dims that can be folded away from transfer ops. It
1227 /// returns a failure if it cannot determine the number of dims to be folded.
1228 /// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
1229 /// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims
1230 /// can be dropped by memref.subview ops.
1231 /// Example 2: it returns "1" if `srcType` is the same memref type with
1232 /// [8192, 16, 8, 1] strides.
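/// (In Example 2 only the innermost dim folds: the next unit dim has stride 8,
/// so it is not contiguous.)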
1233 static FailureOr<size_t>
1234 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1235   SmallVector<int64_t> srcStrides;
1236   int64_t srcOffset;
1237   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1238     return failure();
1239 
1240   auto isUnitDim = [](VectorType type, int dim) {
1241     return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1242   };
1243 
1244   // According to vector.transfer_read/write semantics, the vector can be a
1245   // slice. Thus, we have to offset the check index with `rankDiff` in
1246   // `srcStrides` and source dim sizes.
1247   size_t result = 0;
1248   int rankDiff = srcType.getRank() - vectorType.getRank();
1249   for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1250     // Check that the inner dim size is 1 for both memref type and vector slice.
1251     // It can be folded only if they are 1 and the stride is 1.
1252     int dim = vectorType.getRank() - i - 1;
1253     if (srcStrides[dim + rankDiff] != 1 ||
1254         srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1255       break;
1256     result++;
1257   }
1258   return result;
1259 }
1260 
1261 /// Drop innermost contiguous unit dimensions from the transfer_read operand.
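/// E.g. (a sketch mirroring the transfer_write example below):
///
///    %1 = vector.transfer_read %arg0[%c0, %arg2, %c0, %c0, %c0], %pad
///      {in_bounds = [true, true, true, true, true]}
///      : memref<1x512x16x1x1xf32>, vector<1x16x16x1x1xf32>
///
/// is rewritten into a rank-reducing memref.subview of %arg0, a transfer_read
/// of vector<1x16x16xf32> from that subview, and a vector.shape_cast from
/// vector<1x16x16xf32> back to vector<1x16x16x1x1xf32>.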
1262 class DropInnerMostUnitDimsTransferRead
1263     : public OpRewritePattern<vector::TransferReadOp> {
1264   using OpRewritePattern::OpRewritePattern;
1265 
1266   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1267                                 PatternRewriter &rewriter) const override {
1268     // TODO: support 0-d corner case.
1269     if (readOp.getTransferRank() == 0)
1270       return failure();
1271 
1272     // TODO: support mask.
1273     if (readOp.getMask())
1274       return failure();
1275 
1276     auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1277     if (!srcType)
1278       return failure();
1279 
1280     if (!readOp.getPermutationMap().isMinorIdentity())
1281       return failure();
1282 
1283     auto targetType = readOp.getVectorType();
1284     if (targetType.getRank() <= 1)
1285       return failure();
1286 
1287     FailureOr<size_t> maybeDimsToDrop =
1288         getTransferFoldableInnerUnitDims(srcType, targetType);
1289     if (failed(maybeDimsToDrop))
1290       return failure();
1291 
1292     size_t dimsToDrop = maybeDimsToDrop.value();
1293     if (dimsToDrop == 0)
1294       return failure();
1295 
1296     // Make sure that the indices to be dropped are all equal to 0.
1297     // TODO: Deal with cases when the indices are not 0.
1298     if (!llvm::all_of(readOp.getIndices().take_back(dimsToDrop), isZeroIndex))
1299       return failure();
1300 
1301     auto resultTargetVecType =
1302         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1303                         targetType.getElementType(),
1304                         targetType.getScalableDims().drop_back(dimsToDrop));
1305 
1306     auto loc = readOp.getLoc();
1307     SmallVector<OpFoldResult> sizes =
1308         memref::getMixedSizes(rewriter, loc, readOp.getSource());
1309     SmallVector<OpFoldResult> offsets(srcType.getRank(),
1310                                       rewriter.getIndexAttr(0));
1311     SmallVector<OpFoldResult> strides(srcType.getRank(),
1312                                       rewriter.getIndexAttr(1));
1313     auto resultMemrefType =
1314         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1315             srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1316             strides));
1317     ArrayAttr inBoundsAttr =
1318         readOp.getInBounds()
1319             ? rewriter.getArrayAttr(
1320                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1321             : ArrayAttr();
1322     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1323         loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
1324     auto permMap = getTransferMinorIdentityMap(
1325         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1326     Value result = rewriter.create<vector::TransferReadOp>(
1327         loc, resultTargetVecType, rankedReducedView,
1328         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1329         readOp.getPadding(),
1330         // TODO: support mask.
1331         /*mask=*/Value(), inBoundsAttr);
1332     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1333                                                      result);
1334     return success();
1335   }
1336 };
1337 
1338 /// Drop innermost contiguous unit dimensions from the transfer_write operand.
1339 /// E.g.,
1340 ///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1341 ///      {in_bounds = [true, true, true, true, true]}
1342 ///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1343 ///
1344 /// will be replaced with
1345 ///
1346 ///    %subview = memref.subview %arg0
1347 ///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1348 ///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1349 ///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1350 ///      to vector<1x16x16xf32>
1351 ///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1352 ///      {in_bounds = [true, true, true]}
1353 ///      : vector<1x16x16xf32>, memref<1x512x16xf32>
1354 class DropInnerMostUnitDimsTransferWrite
1355     : public OpRewritePattern<vector::TransferWriteOp> {
1356   using OpRewritePattern::OpRewritePattern;
1357 
1358   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1359                                 PatternRewriter &rewriter) const override {
1360     // TODO: support 0-d corner case.
1361     if (writeOp.getTransferRank() == 0)
1362       return failure();
1363 
1364     // TODO: support mask.
1365     if (writeOp.getMask())
1366       return failure();
1367 
1368     auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1369     if (!srcType)
1370       return failure();
1371 
1372     if (!writeOp.getPermutationMap().isMinorIdentity())
1373       return failure();
1374 
1375     auto targetType = writeOp.getVectorType();
1376     if (targetType.getRank() <= 1)
1377       return failure();
1378 
1379     FailureOr<size_t> maybeDimsToDrop =
1380         getTransferFoldableInnerUnitDims(srcType, targetType);
1381     if (failed(maybeDimsToDrop))
1382       return failure();
1383 
1384     size_t dimsToDrop = maybeDimsToDrop.value();
1385     if (dimsToDrop == 0)
1386       return failure();
1387 
1388     auto resultTargetVecType =
1389         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1390                         targetType.getElementType(),
1391                         targetType.getScalableDims().drop_back(dimsToDrop));
1392 
1393     Location loc = writeOp.getLoc();
1394     SmallVector<OpFoldResult> sizes =
1395         memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1396     SmallVector<OpFoldResult> offsets(srcType.getRank(),
1397                                       rewriter.getIndexAttr(0));
1398     SmallVector<OpFoldResult> strides(srcType.getRank(),
1399                                       rewriter.getIndexAttr(1));
1400     auto resultMemrefType =
1401         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1402             srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1403             strides));
1404     ArrayAttr inBoundsAttr =
1405         writeOp.getInBounds()
1406             ? rewriter.getArrayAttr(
1407                   writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1408             : ArrayAttr();
1409 
1410     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1411         loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
1412     auto permMap = getTransferMinorIdentityMap(
1413         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1414 
1415     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1416         loc, resultTargetVecType, writeOp.getVector());
1417     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1418         writeOp, shapeCast, rankedReducedView,
1419         writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1420         // TODO: support mask.
1421         /*mask=*/Value(), inBoundsAttr);
1422     return success();
1423   }
1424 };
1425 
1426 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1427 /// semantics to a contraction suitable for MMT (matrix-matrix multiplication
1428 /// with the RHS transposed) lowering.
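///
/// The canonical "TNT" form targeted here uses the indexing maps below
/// (A row-major, B col-major, C row-major); shown for illustration only:
///
///   A: affine_map<(m, n, k) -> (m, k)>
///   B: affine_map<(m, n, k) -> (n, k)>
///   C: affine_map<(m, n, k) -> (m, n)>
///
/// Other gemm-like forms are rewritten into this form by inserting
/// vector.transpose ops on the operands and, where needed, swapping LHS/RHS.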
1429 struct CanonicalizeContractMatmulToMMT final
1430     : OpRewritePattern<vector::ContractionOp> {
1431   using OpRewritePattern::OpRewritePattern;
1432 
1433   using FilterConstraintType =
1434       std::function<LogicalResult(vector::ContractionOp op)>;
1435 
1436   CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1437                                   FilterConstraintType constraint)
1438       : OpRewritePattern<vector::ContractionOp>(context, benefit),
1439         filter(std::move(constraint)) {}
1440 
1441   LogicalResult matchAndRewrite(vector::ContractionOp op,
1442                                 PatternRewriter &rewriter) const override {
1443     if (failed(filter(op)))
1444       return failure();
1445 
1446     Location loc = op.getLoc();
1447     Value lhs = op.getLhs();
1448     Value rhs = op.getRhs();
1449     Value res = op.getAcc();
1450 
1451     // Set up the parallel/reduction structure in the right form.
1452     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1453     auto infer = [&](MapList m) {
1454       return AffineMap::inferFromExprList(m, op.getContext());
1455     };
1456     AffineExpr m;
1457     AffineExpr n;
1458     AffineExpr k;
1459     bindDims(rewriter.getContext(), m, n, k);
1460     static constexpr std::array<int64_t, 2> perm = {1, 0};
1461     auto iteratorTypes = op.getIteratorTypes().getValue();
1462     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1463     if (iteratorTypes.size() != 3 ||
1464         !vector::isParallelIterator(iteratorTypes[0]) ||
1465         !vector::isParallelIterator(iteratorTypes[1]) ||
1466         !vector::isReductionIterator(iteratorTypes[2]))
1467       return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1468 
1469     // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1470     const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1471     if (maps == canonicalForm)
1472       return rewriter.notifyMatchFailure(op, "already in the canonical form");
1473 
1474     // Create a vector transpose, making sure to re-emit any zero/sign
1475     // extension at the end.
1476     auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1477       if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1478         Value trans =
1479             rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1480         VectorType newType =
1481             cast<VectorType>(trans.getType())
1482                 .clone(cast<VectorType>(mat.getType()).getElementType());
1483         return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1484       }
1485       if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1486         Value trans =
1487             rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1488         VectorType newType =
1489             VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1490                             cast<VectorType>(mat.getType()).getElementType());
1491         return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1492       }
1493       return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1494     };
1495 
1496     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1497       rhs = createTranspose(rhs);
1498     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1499       lhs = createTranspose(lhs);
1500     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1501       rhs = createTranspose(rhs);
1502       lhs = createTranspose(lhs);
1503     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1504       std::swap(rhs, lhs);
1505       rhs = createTranspose(rhs);
1506       lhs = createTranspose(lhs);
1507     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1508       std::swap(rhs, lhs);
1509       rhs = createTranspose(rhs);
1510     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1511       std::swap(lhs, rhs);
1512       lhs = createTranspose(lhs);
1513     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1514       std::swap(lhs, rhs);
1515     } else {
1516       return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1517     }
1518     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1519         op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1520         op.getIteratorTypes());
1521     return success();
1522   };
1523 
1524 private:
1525   FilterConstraintType filter;
1526 };
1527 
1528 /// Pattern to fold arithmetic extensions on floating point data types into
1529 /// vector contraction operations. linalg.matmul introduces arithmetic
1530 /// extensions on its operands. Please see the MLIR snippet below for details.
1531 /// ```mlir
1532 ///   "linalg.matmul"(%lhs, %rhs, %acc) ({
1533 ///      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1534 ///        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1535 ///        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1536 ///        %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1537 ///        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1538 ///        "linalg.yield"(%acc) : (f32) -> ()
1539 ///     })
1540 /// ```
1541 /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1542 /// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1543 /// This pattern folds the arithmetic extensions into the vector contraction and
1544 /// enables the usage of native mixed precision Tensor Core instructions.
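///
/// At the vector level, the fold is roughly (a sketch; types and value names
/// are illustrative, not taken from the snippet above):
///
/// ```mlir
///   %lhs_f32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
///   %rhs_f32 = arith.extf %rhs : vector<4x8xf16> to vector<4x8xf32>
///   %d = vector.contract {...} %lhs_f32, %rhs_f32, %acc
///  ==>
///   %d = vector.contract {...} %lhs, %rhs, %acc
/// ```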
1545 struct FoldArithExtIntoContractionOp
1546     : public OpRewritePattern<vector::ContractionOp> {
1547   using OpRewritePattern::OpRewritePattern;
1548 
1549   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1550                                 PatternRewriter &rewriter) const override {
1551 
1552     auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
1553     auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
1554 
1555     if (!lhsDefOp || !rhsDefOp) {
1556       return rewriter.notifyMatchFailure(contractOp,
1557                                          "no defining op on contract operands");
1558     }
1559 
1560     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1561         contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1562         contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1563         contractOp.getIteratorTypesAttr());
1564 
1565     return success();
1566   }
1567 };
1568 
1569 /// Pattern to fold a chained reduction into a series of vector additions and
1570 /// a final reduction. This form should require fewer subgroup operations.
1571 ///
1572 /// ```mlir
1573 /// %a = vector.reduction <add> %x, %acc
1574 /// %b = vector.reduction <add> %y, %a
1575 ///  ==>
1576 /// %a = arith.addf %x, %y
1577 /// %b = vector.reduction <add> %a, %acc
1578 /// ```
1579 struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1580   using OpRewritePattern::OpRewritePattern;
1581 
1582   LogicalResult matchAndRewrite(vector::ReductionOp op,
1583                                 PatternRewriter &rewriter) const override {
1584     // TODO: Handle other combining kinds.
1585     if (op.getKind() != vector::CombiningKind::ADD)
1586       return failure();
1587 
1588     // Accumulator is optional.
1589     Value acc = op.getAcc();
1590     if (!acc)
1591       return failure();
1592 
1593     if (!acc.getType().isIntOrFloat())
1594       return failure();
1595 
1596     auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1597     if (!parentReduction)
1598       return failure();
1599 
1600     Location loc = op.getLoc();
1601     Value vAdd;
1602     if (isa<IntegerType>(acc.getType())) {
1603       vAdd = rewriter.createOrFold<arith::AddIOp>(
1604           loc, parentReduction.getVector(), op.getVector());
1605     } else {
1606       vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1607                                             op.getVector());
1608     }
1609     rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1610                                                      parentReduction.getAcc());
1611     return success();
1612   }
1613 };
1614 
1615 // Scalable unit dimensions are not supported. Folding such dimensions would
1616 // require "shifting" the scalable flag onto some other fixed-width dim (e.g.
1617 // vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
1618 // future.
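// For example (illustrative): vector<1x[4]x1xf32> becomes vector<[4]xf32>,
// while vector<[1]x4xf32> is left unchanged because its unit dim is scalable.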
1619 static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1620   auto inVecShape = inVecTy.getShape();
1621   SmallVector<int64_t> newShape;
1622   SmallVector<bool> newScalableDims;
1623   for (auto [dim, isScalable] :
1624        llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1625     if (dim == 1 && !isScalable)
1626       continue;
1627 
1628     newShape.push_back(dim);
1629     newScalableDims.push_back(isScalable);
1630   }
1631 
1632   return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1633 }
1634 
1635 /// For vectors with at least one unit dim, replaces:
1636 ///   elementwise(a, b)
1637 /// with:
1638 ///   sc_a = shape_cast(a)
1639 ///   sc_b = shape_cast(b)
1640 ///   res = elementwise(sc_a, sc_b)
1641 ///   return shape_cast(res)
1642 /// The newly inserted shape_cast Ops fold away the unit dim before the
1643 /// elementwise Op and restore it afterwards. Vectors `a` and `b` are required
1644 /// to have rank > 1.
1645 ///
1646 /// Ex:
1647 ///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1648 ///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1649 ///
1650 /// gets converted to:
1651 ///
1652 ///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1653 ///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1654 ///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1655 ///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1656 ///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1657 ///
1658 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1659 /// `%cast`.
1660 struct DropUnitDimFromElementwiseOps final
1661     : public OpTraitRewritePattern<OpTrait::Elementwise> {
1662   using OpTraitRewritePattern::OpTraitRewritePattern;
1663   LogicalResult matchAndRewrite(Operation *op,
1664                                 PatternRewriter &rewriter) const override {
1665     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1666       return failure();
1667 
1668     auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1669     if (!resultVectorType)
1670       return failure();
1671 
1672     // Check the operand pre-conditions. For `Elementwise` ops all operands are
1673     // guaranteed to have identical shapes (with some exceptions such as
1674     // `arith.select`) and it suffices to only check one of them.
1675     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1676     if (!sourceVectorType || sourceVectorType.getRank() < 2)
1677       return failure();
1678 
1679     SmallVector<Value> newOperands;
1680     auto loc = op->getLoc();
1681     for (auto operand : op->getOperands()) {
1682       auto opVectorType = cast<VectorType>(operand.getType());
1683       auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1684       if (newVType == opVectorType)
1685         return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1686 
1687       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
1688       newOperands.push_back(opSC);
1689     }
1690 
1691     VectorType newResultVectorType =
1692         dropNonScalableUnitDimFromType(resultVectorType);
1693     // Create an updated elementwise Op without the unit dims.
1694     Operation *elementwiseOp =
1695         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1696                         newResultVectorType, op->getAttrs());
1697 
1698     // Restore the unit dim by applying vector.shape_cast to the result.
1699     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
1700                                              elementwiseOp->getResult(0));
1701 
1702     return success();
1703   }
1704 };
1705 
1706 /// Pattern to eliminate redundant zero-constants added to reduction operands.
1707 /// It's enough for there to be one initial zero value, so we can eliminate the
1708 /// extra ones that feed into `vector.reduction <add>`. These get created by the
1709 /// `ChainedReduction` pattern.
1710 ///
1711 /// ```mlir
1712 /// %a = arith.addf %x, %zero
1713 /// %b = arith.addf %a, %y
1714 /// %c = vector.reduction <add> %b, %acc
1715 ///  ==>
1716 /// %b = arith.addf %a, %y
1717 /// %c = vector.reduction <add> %b, %acc
1718 /// ```
1719 struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
1720   using OpRewritePattern::OpRewritePattern;
1721 
1722   LogicalResult matchAndRewrite(vector::ReductionOp op,
1723                                 PatternRewriter &rewriter) const override {
1724     // TODO: Handle other reduction kinds and their identity values.
1725     if (op.getKind() != vector::CombiningKind::ADD)
1726       return failure();
1727 
1728     Type elemType = op.getSourceVectorType().getElementType();
1729     // The integer case should be handled by `arith.addi` folders, only check
1730     // for floats here.
1731     if (!isa<FloatType>(elemType))
1732       return failure();
1733 
1734     auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
1735     if (!vAdd)
1736       return failure();
1737     auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
1738     if (!addLhs)
1739       return failure();
1740 
1741     if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
1742       return failure();
1743 
1744     auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
1745                                                  vAdd.getRhs());
1746     rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
1747                                                      op.getAcc());
1748     return success();
1749   }
1750 };
1751 
1752 /// Example:
1753 /// ```
1754 /// %a = vector.reduction <add> %x : vector<2xf32> into f32
1755 /// ```
1756 /// is transformed into:
1757 /// ```
1758 /// %y = vector.extract %x[0] : f32 from vector<2xf32>
1759 /// %z = vector.extract %x[1] : f32 from vector<2xf32>
1760 /// %a = arith.addf %y, %z : f32
1761 /// ```
1762 struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
1763   BreakDownVectorReduction(MLIRContext *context,
1764                            unsigned maxNumElementsToExtract,
1765                            PatternBenefit benefit)
1766       : OpRewritePattern(context, benefit),
1767         maxNumElementsToExtract(maxNumElementsToExtract) {}
1768 
1769   LogicalResult matchAndRewrite(vector::ReductionOp op,
1770                                 PatternRewriter &rewriter) const override {
1771     VectorType type = op.getSourceVectorType();
1772     if (type.isScalable() || op.isMasked())
1773       return failure();
1774     assert(type.getRank() == 1 && "Expected a 1-d vector");
1775 
1776     int64_t numElems = type.getNumElements();
1777     if (numElems > maxNumElementsToExtract) {
1778       return rewriter.notifyMatchFailure(
1779           op, llvm::formatv("has too many vector elements ({0}) to break down "
1780                             "(max allowed: {1})",
1781                             numElems, maxNumElementsToExtract));
1782     }
1783 
1784     Location loc = op.getLoc();
1785     SmallVector<Value> extracted(numElems, nullptr);
1786     for (auto [idx, extractedElem] : llvm::enumerate(extracted))
1787       extractedElem = rewriter.create<vector::ExtractOp>(
1788           loc, op.getVector(), static_cast<int64_t>(idx));
1789 
1790     Value res = extracted.front();
1791     for (auto extractedElem : llvm::drop_begin(extracted))
1792       res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
1793                                        extractedElem, op.getFastmathAttr());
1794     if (Value acc = op.getAcc())
1795       res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
1796                                        op.getFastmathAttr());
1797 
1798     rewriter.replaceOp(op, res);
1799     return success();
1800   }
1801 
1802 private:
1803   unsigned maxNumElementsToExtract = 0;
1804 };
1805 
1806 } // namespace
1807 
1808 void mlir::vector::populateFoldArithExtensionPatterns(
1809     RewritePatternSet &patterns) {
1810   patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
1811 }
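
// A minimal usage sketch for the populate* entry points in this file (an
// illustration, not part of this file's API; `applyPatternsAndFoldGreedily`
// comes from the greedy pattern rewrite driver):
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateFoldArithExtensionPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     return failure();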
1812 
1813 void mlir::vector::populateVectorMaskMaterializationPatterns(
1814     RewritePatternSet &patterns, bool force32BitVectorIndices,
1815     PatternBenefit benefit) {
1816   patterns.add<VectorCreateMaskOpConversion,
1817                MaterializeTransferMask<vector::TransferReadOp>,
1818                MaterializeTransferMask<vector::TransferWriteOp>>(
1819       patterns.getContext(), force32BitVectorIndices, benefit);
1820   patterns.add<FoldI1Select>(patterns.getContext(), benefit);
1821 }
1822 
1823 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
1824                                                     PatternBenefit benefit) {
1825   patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
1826 }
1827 
1828 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
1829     RewritePatternSet &patterns, PatternBenefit benefit) {
1830   patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1831       patterns.getContext(), benefit);
1832 }
1833 
1834 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
1835     RewritePatternSet &patterns, PatternBenefit benefit) {
1836   patterns.add<BubbleDownVectorBitCastForExtract,
1837                BubbleDownBitCastForStridedSliceExtract,
1838                BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
1839       patterns.getContext(), benefit);
1840 }
1841 
1842 void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
1843     RewritePatternSet &patterns,
1844     std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
1845   patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
1846                                        std::move(controlFn), benefit);
1847 }
1848 
1849 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
1850     RewritePatternSet &patterns,
1851     std::function<LogicalResult(vector::ContractionOp)> constraint,
1852     PatternBenefit benefit) {
1853   patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
1854                                                 std::move(constraint));
1855 }
1856 
1857 void mlir::vector::populateVectorReductionToContractPatterns(
1858     RewritePatternSet &patterns, PatternBenefit benefit) {
1859   patterns.add<MultiReduceToContract, CombineContractBroadcast,
1860                CombineContractABTranspose, CombineContractResultTranspose,
1861                ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
1862       patterns.getContext(), benefit);
1863 }
1864 
1865 void mlir::vector::
1866     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
1867         RewritePatternSet &patterns, PatternBenefit benefit) {
1868   patterns.add<DropInnerMostUnitDimsTransferRead,
1869                DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
1870                                                    benefit);
1871 }
1872 
1873 void mlir::vector::populateSinkVectorBroadcastPatterns(
1874     RewritePatternSet &patterns, PatternBenefit benefit) {
1875   patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
1876       patterns.getContext(), benefit);
1877 }
1878 
1879 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
1880     RewritePatternSet &patterns, PatternBenefit benefit) {
1881   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
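  // Note: ReduceRedundantZero is registered with a slightly higher benefit,
  // presumably so that it is tried before ChainedReduction and can clean up
  // the redundant zero-adds that ChainedReduction introduces (see its
  // documentation above).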
1882   patterns.add<ReduceRedundantZero>(patterns.getContext(),
1883                                     PatternBenefit(benefit.getBenefit() + 1));
1884 }
1885 
1886 void mlir::vector::populateBreakDownVectorReductionPatterns(
1887     RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
1888     PatternBenefit benefit) {
1889   patterns.add<BreakDownVectorReduction>(patterns.getContext(),
1890                                          maxNumElementsToExtract, benefit);
1891 }
1892 
1893 //===----------------------------------------------------------------------===//
1894 // TableGen'd enum attribute definitions
1895 //===----------------------------------------------------------------------===//
1896 
1897 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
1898