xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//
1299ef9eebSMatthias Springer 
139b5a3d14SMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
149b5a3d14SMatthias Springer 
1507677113SJakub Kuderski #include <cassert>
16f80a976aSJakub Kuderski #include <cstdint>
17fb7ef637SJakub Kuderski #include <functional>
185382d288SKazu Hirata #include <optional>
1999ef9eebSMatthias Springer #include <type_traits>
2099ef9eebSMatthias Springer 
2199ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
22abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
23abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
2499ef9eebSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
2599ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
268b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
27f80a976aSJakub Kuderski #include "mlir/Dialect/Tensor/IR/Tensor.h"
28f71f9958SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h"
2999ef9eebSMatthias Springer #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
304758e916SOleg Shyshkov #include "mlir/Dialect/Vector/IR/VectorOps.h"
31fb7ef637SJakub Kuderski #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
329b5a3d14SMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
3383df43f3SAlex Zinenko #include "mlir/IR/BuiltinAttributeInterfaces.h"
344db65e27SLei Zhang #include "mlir/IR/BuiltinTypes.h"
3599ef9eebSMatthias Springer #include "mlir/IR/ImplicitLocOpBuilder.h"
36f80a976aSJakub Kuderski #include "mlir/IR/Location.h"
3799ef9eebSMatthias Springer #include "mlir/IR/Matchers.h"
3899ef9eebSMatthias Springer #include "mlir/IR/PatternMatch.h"
39f80a976aSJakub Kuderski #include "mlir/IR/TypeUtilities.h"
4099ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h"
4199ef9eebSMatthias Springer 
4299ef9eebSMatthias Springer #include "llvm/ADT/DenseSet.h"
4399ef9eebSMatthias Springer #include "llvm/ADT/MapVector.h"
4499ef9eebSMatthias Springer #include "llvm/ADT/STLExtras.h"
4599ef9eebSMatthias Springer #include "llvm/Support/CommandLine.h"
4699ef9eebSMatthias Springer #include "llvm/Support/Debug.h"
4707677113SJakub Kuderski #include "llvm/Support/FormatVariadic.h"
4899ef9eebSMatthias Springer #include "llvm/Support/raw_ostream.h"
4999ef9eebSMatthias Springer 
5099ef9eebSMatthias Springer #define DEBUG_TYPE "vector-to-vector"
5199ef9eebSMatthias Springer 
5299ef9eebSMatthias Springer using namespace mlir;
5399ef9eebSMatthias Springer using namespace mlir::vector;
5499ef9eebSMatthias Springer 
5599ef9eebSMatthias Springer template <typename IntType>
567a69a9d7SNicolas Vasilache static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
5799ef9eebSMatthias Springer   return llvm::to_vector<4>(llvm::map_range(
5899ef9eebSMatthias Springer       arrayAttr.getAsRange<IntegerAttr>(),
5999ef9eebSMatthias Springer       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
6099ef9eebSMatthias Springer }
6199ef9eebSMatthias Springer 
622bc4c3e9SNicolas Vasilache // Helper to find an index in an affine map.
632bc4c3e9SNicolas Vasilache static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
642bc4c3e9SNicolas Vasilache   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
652bc4c3e9SNicolas Vasilache     int64_t idx = map.getDimPosition(i);
662bc4c3e9SNicolas Vasilache     if (idx == index)
6789aaa2d0SThomas Raoux       return i;
6889aaa2d0SThomas Raoux   }
691a36588eSKazu Hirata   return std::nullopt;
7089aaa2d0SThomas Raoux }
7189aaa2d0SThomas Raoux 
7299ef9eebSMatthias Springer namespace {
7399ef9eebSMatthias Springer 
7499ef9eebSMatthias Springer /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
7599ef9eebSMatthias Springer //
7699ef9eebSMatthias Springer // Example:
7799ef9eebSMatthias Springer //
7899ef9eebSMatthias Springer //  The following MLIR with cancelling ShapeCastOps:
7999ef9eebSMatthias Springer //
8099ef9eebSMatthias Springer //   %0 = source : vector<5x4x2xf32>
8199ef9eebSMatthias Springer //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
8299ef9eebSMatthias Springer //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
8399ef9eebSMatthias Springer //   %3 = user %2 : vector<5x4x2xf32>
8499ef9eebSMatthias Springer //
8599ef9eebSMatthias Springer //  Should canonicalize to the following:
8699ef9eebSMatthias Springer //
8799ef9eebSMatthias Springer //   %0 = source : vector<5x4x2xf32>
8899ef9eebSMatthias Springer //   %1 = user %0 : vector<5x4x2xf32>
8999ef9eebSMatthias Springer //
9099ef9eebSMatthias Springer struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
9127cc31b6SNicolas Vasilache   using OpRewritePattern::OpRewritePattern;
9299ef9eebSMatthias Springer 
9399ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
9499ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
9599ef9eebSMatthias Springer     // Check if 'shapeCastOp' has vector source/result type.
9699ef9eebSMatthias Springer     auto sourceVectorType =
975550c821STres Popp         dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
9899ef9eebSMatthias Springer     auto resultVectorType =
995550c821STres Popp         dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
10099ef9eebSMatthias Springer     if (!sourceVectorType || !resultVectorType)
10199ef9eebSMatthias Springer       return failure();
10299ef9eebSMatthias Springer 
10399ef9eebSMatthias Springer     // Check if shape cast op source operand is also a shape cast op.
10499ef9eebSMatthias Springer     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
1057c38fd60SJacques Pienaar         shapeCastOp.getSource().getDefiningOp());
10699ef9eebSMatthias Springer     if (!sourceShapeCastOp)
10799ef9eebSMatthias Springer       return failure();
10899ef9eebSMatthias Springer     auto operandSourceVectorType =
1095550c821STres Popp         cast<VectorType>(sourceShapeCastOp.getSource().getType());
11099ef9eebSMatthias Springer     auto operandResultVectorType = sourceShapeCastOp.getType();
11199ef9eebSMatthias Springer 
11299ef9eebSMatthias Springer     // Check if shape cast operations invert each other.
11399ef9eebSMatthias Springer     if (operandSourceVectorType != resultVectorType ||
11499ef9eebSMatthias Springer         operandResultVectorType != sourceVectorType)
11599ef9eebSMatthias Springer       return failure();
11699ef9eebSMatthias Springer 
1177c38fd60SJacques Pienaar     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
11899ef9eebSMatthias Springer     return success();
11999ef9eebSMatthias Springer   }
12099ef9eebSMatthias Springer };
12199ef9eebSMatthias Springer 
12299ef9eebSMatthias Springer /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
12399ef9eebSMatthias Springer /// Ex:
12499ef9eebSMatthias Springer /// ```
12599ef9eebSMatthias Springer ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
12699ef9eebSMatthias Springer ///   %1 = vector.multi_reduction add, %0 [1]
12799ef9eebSMatthias Springer ///     : vector<8x32x16xf32> to vector<8x16xf32>
12899ef9eebSMatthias Springer /// ```
12999ef9eebSMatthias Springer /// Gets converted to:
13099ef9eebSMatthias Springer /// ```
13199ef9eebSMatthias Springer ///   %1 = vector.contract {indexing_maps = [
13299ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
13399ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
13499ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
13599ef9eebSMatthias Springer ///    iterator_types = ["parallel", "parallel", "reduction"],
13699ef9eebSMatthias Springer ///    kind = add} %0, %arg1, %cst_f0
13799ef9eebSMatthias Springer ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
13899ef9eebSMatthias Springer ///  ```
13999ef9eebSMatthias Springer struct MultiReduceToContract
14099ef9eebSMatthias Springer     : public OpRewritePattern<vector::MultiDimReductionOp> {
14127cc31b6SNicolas Vasilache   using OpRewritePattern::OpRewritePattern;
14299ef9eebSMatthias Springer 
14399ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
14499ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
1457c38fd60SJacques Pienaar     if (reduceOp.getKind() != vector::CombiningKind::ADD)
14699ef9eebSMatthias Springer       return failure();
1477c38fd60SJacques Pienaar     Operation *mulOp = reduceOp.getSource().getDefiningOp();
14899ef9eebSMatthias Springer     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
14999ef9eebSMatthias Springer       return failure();
15099ef9eebSMatthias Springer     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
15199ef9eebSMatthias Springer     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
15299ef9eebSMatthias Springer     SmallVector<AffineExpr> exprs;
1534758e916SOleg Shyshkov     SmallVector<vector::IteratorType> iteratorTypes;
15499ef9eebSMatthias Springer     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
15599ef9eebSMatthias Springer       if (!isReduceDim.value()) {
1564758e916SOleg Shyshkov         iteratorTypes.push_back(vector::IteratorType::parallel);
15799ef9eebSMatthias Springer         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
15899ef9eebSMatthias Springer       } else {
1594758e916SOleg Shyshkov         iteratorTypes.push_back(vector::IteratorType::reduction);
16099ef9eebSMatthias Springer       }
16199ef9eebSMatthias Springer     }
162fe8a62c4SUday Bondhugula     auto dstMap =
163fe8a62c4SUday Bondhugula         AffineMap::get(/*dimCount=*/reductionMask.size(),
164fe8a62c4SUday Bondhugula                        /*symbolCount=*/0, exprs, reduceOp.getContext());
16599ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
166051b36baSThomas Raoux         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
16799ef9eebSMatthias Springer         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
1684758e916SOleg Shyshkov         rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
1694758e916SOleg Shyshkov             iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
1704758e916SOleg Shyshkov               return IteratorTypeAttr::get(rewriter.getContext(), t);
1714758e916SOleg Shyshkov             }))));
17299ef9eebSMatthias Springer     return success();
17399ef9eebSMatthias Springer   }
17499ef9eebSMatthias Springer };
17599ef9eebSMatthias Springer 
176f0c93fd4SLei Zhang /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
17799ef9eebSMatthias Springer /// Ex:
17899ef9eebSMatthias Springer /// ```
17999ef9eebSMatthias Springer ///   %0 = vector.transpose %arg0, [2, 0, 1]
18099ef9eebSMatthias Springer ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
18199ef9eebSMatthias Springer ///   %1 = vector.contract {indexing_maps = [
18299ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
18399ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
18499ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
18599ef9eebSMatthias Springer ///    iterator_types = ["parallel", "parallel", "reduction"],
18699ef9eebSMatthias Springer ///    kind = add} %0, %arg1, %cst_f0
18799ef9eebSMatthias Springer ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
18899ef9eebSMatthias Springer /// ```
18999ef9eebSMatthias Springer /// Gets converted to:
19099ef9eebSMatthias Springer /// ```
19199ef9eebSMatthias Springer ///   %1 = vector.contract {indexing_maps = [
19299ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
19399ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
19499ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
19599ef9eebSMatthias Springer ///    iterator_types = ["parallel", "parallel", "reduction"],
19699ef9eebSMatthias Springer ///    kind = add} %arg0, %arg1, %cst_f0
19799ef9eebSMatthias Springer ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
19899ef9eebSMatthias Springer ///  ```
199f0c93fd4SLei Zhang struct CombineContractABTranspose final
20099ef9eebSMatthias Springer     : public OpRewritePattern<vector::ContractionOp> {
20127cc31b6SNicolas Vasilache   using OpRewritePattern::OpRewritePattern;
20299ef9eebSMatthias Springer 
20399ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
20499ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
2057a69a9d7SNicolas Vasilache     SmallVector<AffineMap> maps =
206d2c0572bSJacques Pienaar         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
2077c38fd60SJacques Pienaar     Value lhs = contractOp.getLhs();
2087c38fd60SJacques Pienaar     Value rhs = contractOp.getRhs();
20999ef9eebSMatthias Springer     size_t index = 0;
21099ef9eebSMatthias Springer     bool changed = false;
21199ef9eebSMatthias Springer     for (Value *operand : {&lhs, &rhs}) {
21299ef9eebSMatthias Springer       AffineMap &map = maps[index++];
21399ef9eebSMatthias Springer       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
21499ef9eebSMatthias Springer       if (!transposeOp)
21599ef9eebSMatthias Springer         continue;
21699ef9eebSMatthias Springer       AffineMap permutationMap = AffineMap::getPermutationMap(
21732c3decbSMatthias Springer           transposeOp.getPermutation(), contractOp.getContext());
21899ef9eebSMatthias Springer       map = inversePermutation(permutationMap).compose(map);
2197c38fd60SJacques Pienaar       *operand = transposeOp.getVector();
22099ef9eebSMatthias Springer       changed = true;
22199ef9eebSMatthias Springer     }
22299ef9eebSMatthias Springer     if (!changed)
22399ef9eebSMatthias Springer       return failure();
22499ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
2257c38fd60SJacques Pienaar         contractOp, lhs, rhs, contractOp.getAcc(),
2267c38fd60SJacques Pienaar         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
22799ef9eebSMatthias Springer     return success();
22899ef9eebSMatthias Springer   }
22999ef9eebSMatthias Springer };
23099ef9eebSMatthias Springer 
231f0c93fd4SLei Zhang /// Merges accumulator and result transposes into contract.
232f0c93fd4SLei Zhang ///
233f0c93fd4SLei Zhang /// For example:
234f0c93fd4SLei Zhang /// ```mlir
235f0c93fd4SLei Zhang /// %accT = vector.transpose %acc, [0, 2, 1]
236f0c93fd4SLei Zhang ///   : vector<2x8x4xf32> to vector<2x4x8xf32>
237f0c93fd4SLei Zhang /// %contract = vector.contract {
238f0c93fd4SLei Zhang ///   indexing_maps = [
239f0c93fd4SLei Zhang ///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
240f0c93fd4SLei Zhang ///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
241f0c93fd4SLei Zhang ///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
242f0c93fd4SLei Zhang ///   ],
243f0c93fd4SLei Zhang ///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
244f0c93fd4SLei Zhang ///   kind = #vector.kind<add>
245f0c93fd4SLei Zhang /// } %lhs, %rhs, %accT
246f0c93fd4SLei Zhang ///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
247f0c93fd4SLei Zhang /// %0 = vector.transpose %contract, [0, 2, 1]
248f0c93fd4SLei Zhang ///   : vector<2x4x8xf32> to vector<2x8x4>
249f0c93fd4SLei Zhang /// ```
250f0c93fd4SLei Zhang /// Becomes:
251f0c93fd4SLei Zhang /// ```mlir
252f0c93fd4SLei Zhang /// %0 = vector.contract {
253f0c93fd4SLei Zhang ///   indexing_maps = [
254f0c93fd4SLei Zhang ///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
255f0c93fd4SLei Zhang ///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
256f0c93fd4SLei Zhang ///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
257f0c93fd4SLei Zhang ///   ],
258f0c93fd4SLei Zhang ///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
259f0c93fd4SLei Zhang ///   kind = #vector.kind<add>
260f0c93fd4SLei Zhang /// } %lhs, %rhs, %acc
261f0c93fd4SLei Zhang ///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
262f0c93fd4SLei Zhang /// ```
263f0c93fd4SLei Zhang struct CombineContractResultTranspose final
264f0c93fd4SLei Zhang     : public OpRewritePattern<vector::TransposeOp> {
265f0c93fd4SLei Zhang   using OpRewritePattern::OpRewritePattern;
266f0c93fd4SLei Zhang 
267f0c93fd4SLei Zhang   LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
268f0c93fd4SLei Zhang                                 PatternRewriter &rewriter) const override {
269f0c93fd4SLei Zhang     auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
270f0c93fd4SLei Zhang     if (!contractOp || !contractOp->hasOneUse())
271f0c93fd4SLei Zhang       return failure();
272f0c93fd4SLei Zhang 
273f0c93fd4SLei Zhang     auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
274f0c93fd4SLei Zhang     if (!accTOp)
275f0c93fd4SLei Zhang       return failure();
276f0c93fd4SLei Zhang 
277f0c93fd4SLei Zhang     MLIRContext *context = contractOp.getContext();
278f0c93fd4SLei Zhang     auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
279f0c93fd4SLei Zhang     AffineMap contractMap = maps.back();
280f0c93fd4SLei Zhang 
281f0c93fd4SLei Zhang     // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
282f0c93fd4SLei Zhang     // To index into A in contract, we need revert(f)(g(C)) -> A.
28332c3decbSMatthias Springer     auto accTMap =
28432c3decbSMatthias Springer         AffineMap::getPermutationMap(accTOp.getPermutation(), context);
285f0c93fd4SLei Zhang 
286f0c93fd4SLei Zhang     // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
287f0c93fd4SLei Zhang     // To index into E in contract, we need h(g(C)) -> E.
28832c3decbSMatthias Springer     auto resTMap =
28932c3decbSMatthias Springer         AffineMap::getPermutationMap(resTOp.getPermutation(), context);
290f0c93fd4SLei Zhang     auto combinedResMap = resTMap.compose(contractMap);
291f0c93fd4SLei Zhang 
292f0c93fd4SLei Zhang     // The accumulator and result share the same indexing map. So they should be
293f0c93fd4SLei Zhang     // the same to be able to merge. This means combinedResMap is the same as
294f0c93fd4SLei Zhang     // inversePermutation(accTMap).compose(contractMap), which means
295f0c93fd4SLei Zhang     if (inversePermutation(accTMap) != resTMap)
296f0c93fd4SLei Zhang       return failure();
297f0c93fd4SLei Zhang     maps.back() = combinedResMap;
298f0c93fd4SLei Zhang 
299f0c93fd4SLei Zhang     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
300f0c93fd4SLei Zhang         resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
301f0c93fd4SLei Zhang         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
302f0c93fd4SLei Zhang     return success();
303f0c93fd4SLei Zhang   }
304f0c93fd4SLei Zhang };
305f0c93fd4SLei Zhang 
30699ef9eebSMatthias Springer /// Merge BroadcastOp into ContractionOp user.
30799ef9eebSMatthias Springer /// Ex:
30899ef9eebSMatthias Springer /// ```
30999ef9eebSMatthias Springer ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
31099ef9eebSMatthias Springer ///   %1 = vector.contract {indexing_maps = [
31199ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
31299ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
31399ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
31499ef9eebSMatthias Springer ///    iterator_types = ["parallel", "parallel", "reduction"],
31599ef9eebSMatthias Springer ///    kind = add} %0, %arg1, %cst_f0
31699ef9eebSMatthias Springer ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
31799ef9eebSMatthias Springer /// ```
31899ef9eebSMatthias Springer /// Gets converted to:
31999ef9eebSMatthias Springer /// ```
32099ef9eebSMatthias Springer ///   %1 = vector.contract {indexing_maps = [
32199ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
32299ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
32399ef9eebSMatthias Springer ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
32499ef9eebSMatthias Springer ///    iterator_types = ["parallel", "parallel", "reduction"],
32599ef9eebSMatthias Springer ///    kind = add} %arg0, %arg1, %cst_f0
32699ef9eebSMatthias Springer ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
32799ef9eebSMatthias Springer ///  ```
32899ef9eebSMatthias Springer struct CombineContractBroadcast
32999ef9eebSMatthias Springer     : public OpRewritePattern<vector::ContractionOp> {
33027cc31b6SNicolas Vasilache   using OpRewritePattern::OpRewritePattern;
33199ef9eebSMatthias Springer 
33299ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
33399ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
3347a69a9d7SNicolas Vasilache     SmallVector<AffineMap> maps =
335d2c0572bSJacques Pienaar         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
3367c38fd60SJacques Pienaar     Value lhs = contractOp.getLhs();
3377c38fd60SJacques Pienaar     Value rhs = contractOp.getRhs();
33899ef9eebSMatthias Springer     size_t index = 0;
33999ef9eebSMatthias Springer     bool changed = false;
34099ef9eebSMatthias Springer     for (Value *operand : {&lhs, &rhs}) {
34199ef9eebSMatthias Springer       AffineMap &map = maps[index++];
34299ef9eebSMatthias Springer       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
34399ef9eebSMatthias Springer       if (!broadcast)
34499ef9eebSMatthias Springer         continue;
34599ef9eebSMatthias Springer       // contractionOp can only take vector as operands.
3465550c821STres Popp       auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
347a1aad28dSLei Zhang       if (!srcType ||
348a1aad28dSLei Zhang           srcType.getRank() == broadcast.getResultVectorType().getRank())
34999ef9eebSMatthias Springer         continue;
35099ef9eebSMatthias Springer       int64_t rankDiff =
351a1aad28dSLei Zhang           broadcast.getResultVectorType().getRank() - srcType.getRank();
35299ef9eebSMatthias Springer       bool innerDimBroadcast = false;
35399ef9eebSMatthias Springer       SmallVector<AffineExpr> originalDims;
35499ef9eebSMatthias Springer       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
355a1aad28dSLei Zhang         if (dim.value() != broadcast.getResultVectorType().getDimSize(
356a1aad28dSLei Zhang                                rankDiff + dim.index())) {
35799ef9eebSMatthias Springer           innerDimBroadcast = true;
35899ef9eebSMatthias Springer           break;
35999ef9eebSMatthias Springer         }
36099ef9eebSMatthias Springer         originalDims.push_back(
36199ef9eebSMatthias Springer             rewriter.getAffineDimExpr(dim.index() + rankDiff));
36299ef9eebSMatthias Springer       }
36399ef9eebSMatthias Springer       // Contract doesn't support inner dimension broadcast. Once this is
36499ef9eebSMatthias Springer       // relaxed we can remove this case.
36599ef9eebSMatthias Springer       if (innerDimBroadcast)
36699ef9eebSMatthias Springer         continue;
367694ad3eaSBenoit Jacob 
368694ad3eaSBenoit Jacob       // It would be incorrect to fold a broadcast onto a reduction dimension
369694ad3eaSBenoit Jacob       // of non-unit size.
370694ad3eaSBenoit Jacob       bool nonUnitDimReductionBroadcast = false;
371694ad3eaSBenoit Jacob       for (int64_t i = 0; i < rankDiff; ++i) {
372a1aad28dSLei Zhang         if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
373694ad3eaSBenoit Jacob             isReductionIterator(contractOp.getIteratorTypes()
374694ad3eaSBenoit Jacob                                     .getValue()[map.getDimPosition(i)])) {
375694ad3eaSBenoit Jacob           nonUnitDimReductionBroadcast = true;
376694ad3eaSBenoit Jacob           break;
377694ad3eaSBenoit Jacob         }
378694ad3eaSBenoit Jacob       }
379694ad3eaSBenoit Jacob       if (nonUnitDimReductionBroadcast)
380694ad3eaSBenoit Jacob         continue;
381694ad3eaSBenoit Jacob 
38299ef9eebSMatthias Springer       AffineMap broadcastMap =
383a1aad28dSLei Zhang           AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
384a1aad28dSLei Zhang                          originalDims, contractOp.getContext());
38599ef9eebSMatthias Springer       map = broadcastMap.compose(map);
3867c38fd60SJacques Pienaar       *operand = broadcast.getSource();
38799ef9eebSMatthias Springer       changed = true;
38899ef9eebSMatthias Springer     }
389694ad3eaSBenoit Jacob 
39099ef9eebSMatthias Springer     if (!changed)
39199ef9eebSMatthias Springer       return failure();
392694ad3eaSBenoit Jacob 
393694ad3eaSBenoit Jacob     // Determine which dims are usused, now that the maps have been composed
394694ad3eaSBenoit Jacob     // with the broadcast maps.
395c3839c0bSBenoit Jacob     llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
396694ad3eaSBenoit Jacob     // Compress unused dims.
397694ad3eaSBenoit Jacob     for (auto &m : maps)
398c3839c0bSBenoit Jacob       m = compressDims(m, unusedDimsBitVector);
399694ad3eaSBenoit Jacob     // Compute the combined iterators.
4007a69a9d7SNicolas Vasilache     SmallVector<Attribute> iterators;
401c3839c0bSBenoit Jacob     for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
402c3839c0bSBenoit Jacob       if (!unusedDimsBitVector.test(i))
403694ad3eaSBenoit Jacob         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
404694ad3eaSBenoit Jacob     }
405f0c3fd18SBenoit Jacob     // Check that compressing unused dims isn't removing all reduction dimension
406f0c3fd18SBenoit Jacob     // pairs. For example, if the vector.contract had only one reduction
407694ad3eaSBenoit Jacob     // iterator and that was a unit-dimension created by a broadcast,
408694ad3eaSBenoit Jacob     // then we should bail here, otherwise we would create a contract without
409f0c3fd18SBenoit Jacob     // a reduction dimension pair.
410f0c3fd18SBenoit Jacob     bool hasReductionIteratorApplyingOnBothSides = false;
411f0c3fd18SBenoit Jacob     for (unsigned i = 0; i < iterators.size(); ++i) {
412f0c3fd18SBenoit Jacob       if (!isReductionIterator(iterators[i]))
413f0c3fd18SBenoit Jacob         continue;
414f0c3fd18SBenoit Jacob       if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
415f0c3fd18SBenoit Jacob         hasReductionIteratorApplyingOnBothSides = true;
416f0c3fd18SBenoit Jacob         break;
417f0c3fd18SBenoit Jacob       }
418f0c3fd18SBenoit Jacob     }
419f0c3fd18SBenoit Jacob     if (!hasReductionIteratorApplyingOnBothSides)
420694ad3eaSBenoit Jacob       return failure();
421f0c3fd18SBenoit Jacob 
422c3839c0bSBenoit Jacob     // If the compressed maps have a dimension that is not used by either LHS or
423c3839c0bSBenoit Jacob     // RHS then the ContractionOp verifier would fail.
424c3839c0bSBenoit Jacob     if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
425c3839c0bSBenoit Jacob       return failure();
42699ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
4277c38fd60SJacques Pienaar         contractOp, lhs, rhs, contractOp.getAcc(),
428694ad3eaSBenoit Jacob         rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
42999ef9eebSMatthias Springer     return success();
43099ef9eebSMatthias Springer   }
43199ef9eebSMatthias Springer };
43299ef9eebSMatthias Springer 
4331538bd51SHanhan Wang /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
4341538bd51SHanhan Wang /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
4351538bd51SHanhan Wang /// casting ops are around these operations.
4361538bd51SHanhan Wang /// Ex:
4371538bd51SHanhan Wang /// ```
4381538bd51SHanhan Wang ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
4391538bd51SHanhan Wang ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
4401538bd51SHanhan Wang /// ```
4411538bd51SHanhan Wang /// Gets converted to:
4421538bd51SHanhan Wang /// ```
4431538bd51SHanhan Wang ///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
4441538bd51SHanhan Wang ///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
4451538bd51SHanhan Wang /// ```
4461538bd51SHanhan Wang struct ReorderCastOpsOnBroadcast
4471538bd51SHanhan Wang     : public OpInterfaceRewritePattern<CastOpInterface> {
4481538bd51SHanhan Wang   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
4491538bd51SHanhan Wang 
4501538bd51SHanhan Wang   LogicalResult matchAndRewrite(CastOpInterface op,
4511538bd51SHanhan Wang                                 PatternRewriter &rewriter) const override {
4521538bd51SHanhan Wang     if (op->getNumOperands() != 1)
4531538bd51SHanhan Wang       return failure();
4541538bd51SHanhan Wang     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
4551538bd51SHanhan Wang     if (!bcastOp)
4561538bd51SHanhan Wang       return failure();
4571538bd51SHanhan Wang 
4581538bd51SHanhan Wang     Type castResTy = getElementTypeOrSelf(op->getResult(0));
4595550c821STres Popp     if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
4604b2ba5a6SAndrzej Warzyński       castResTy = vecTy.clone(castResTy);
46135f48edbSMehdi Amini     auto *castOp =
4627c38fd60SJacques Pienaar         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
4637c38fd60SJacques Pienaar                         bcastOp.getSource(), castResTy, op->getAttrs());
4641538bd51SHanhan Wang     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4651538bd51SHanhan Wang         op, op->getResult(0).getType(), castOp->getResult(0));
4661538bd51SHanhan Wang     return success();
4671538bd51SHanhan Wang   }
4681538bd51SHanhan Wang };
4691538bd51SHanhan Wang 
4704db65e27SLei Zhang /// Reorders elementwise(transpose) to transpose(elementwise). This makes
4714db65e27SLei Zhang /// transpose ops and contraction ops closer, which kicks in
472f0c93fd4SLei Zhang /// CombineContractABTranspose pattern when elementwise ops are between these
4734db65e27SLei Zhang /// operations. Ex:
4741538bd51SHanhan Wang /// ```
4754db65e27SLei Zhang /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
4764db65e27SLei Zhang /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
4774db65e27SLei Zhang /// %r = arith.addf %at, %bt : vector<2x4xf32>
4781538bd51SHanhan Wang /// ```
4791538bd51SHanhan Wang /// Gets converted to:
4801538bd51SHanhan Wang /// ```
4814db65e27SLei Zhang /// %0 = arith.addf %a, %b : vector<4x2xf32>
4824db65e27SLei Zhang /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
4831538bd51SHanhan Wang /// ```
4844db65e27SLei Zhang struct ReorderElementwiseOpsOnTranspose final
4854db65e27SLei Zhang     : public OpTraitRewritePattern<OpTrait::Elementwise> {
4864db65e27SLei Zhang   using OpTraitRewritePattern::OpTraitRewritePattern;
4874db65e27SLei Zhang   LogicalResult matchAndRewrite(Operation *op,
4881538bd51SHanhan Wang                                 PatternRewriter &rewriter) const override {
4894db65e27SLei Zhang     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
4901538bd51SHanhan Wang       return failure();
4911538bd51SHanhan Wang 
4924db65e27SLei Zhang     // Make sure all operands are transpose/constant ops and collect their
4934db65e27SLei Zhang     // transposition maps.
49432c3decbSMatthias Springer     SmallVector<ArrayRef<int64_t>> transposeMaps;
4954db65e27SLei Zhang     transposeMaps.reserve(op->getNumOperands());
4964db65e27SLei Zhang     // Record the initial type before transposition. We'll use its shape later.
4974db65e27SLei Zhang     // Any type will do here as we will check all transpose maps are the same.
4984db65e27SLei Zhang     VectorType srcType;
4994db65e27SLei Zhang     for (Value operand : op->getOperands()) {
5004db65e27SLei Zhang       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
5014db65e27SLei Zhang       if (transposeOp) {
50232c3decbSMatthias Springer         transposeMaps.push_back(transposeOp.getPermutation());
503a1aad28dSLei Zhang         srcType = transposeOp.getSourceVectorType();
5044db65e27SLei Zhang       } else if (!matchPattern(operand, m_Constant())) {
5054db65e27SLei Zhang         return failure();
5064db65e27SLei Zhang       }
5074db65e27SLei Zhang     }
5084db65e27SLei Zhang     if (transposeMaps.empty())
5094db65e27SLei Zhang       return failure();
5104db65e27SLei Zhang     // This is an elementwise op, so all transposed operands should have the
5114db65e27SLei Zhang     // same type. We need to additionally check that all transposes uses the
5124db65e27SLei Zhang     // same map.
5136fa87ec1SJakub Kuderski     if (!llvm::all_equal(transposeMaps))
5144db65e27SLei Zhang       return rewriter.notifyMatchFailure(op, "different transpose map");
5154db65e27SLei Zhang 
5167a69a9d7SNicolas Vasilache     SmallVector<Value> srcValues;
5174db65e27SLei Zhang     srcValues.reserve(op->getNumOperands());
5184db65e27SLei Zhang 
5194db65e27SLei Zhang     // If there are constant operands, we need to insert inverse transposes for
5204db65e27SLei Zhang     // them. Calculate the inverse order first.
52132c3decbSMatthias Springer     auto order = transposeMaps.front();
5224db65e27SLei Zhang     SmallVector<int64_t> invOrder(order.size());
5234db65e27SLei Zhang     for (int i = 0, e = order.size(); i < e; ++i)
5244db65e27SLei Zhang       invOrder[order[i]] = i;
5254db65e27SLei Zhang 
5264db65e27SLei Zhang     for (Value operand : op->getOperands()) {
5274db65e27SLei Zhang       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
5284db65e27SLei Zhang       if (transposeOp) {
5294db65e27SLei Zhang         srcValues.push_back(transposeOp.getVector());
5304db65e27SLei Zhang       } else {
5314db65e27SLei Zhang         // This is a constant. Create a reverse transpose op for it.
5324b2ba5a6SAndrzej Warzyński         auto vectorType =
5334b2ba5a6SAndrzej Warzyński             srcType.clone(cast<VectorType>(operand.getType()).getElementType());
5344db65e27SLei Zhang         srcValues.push_back(rewriter.create<vector::TransposeOp>(
53532c3decbSMatthias Springer             operand.getLoc(), vectorType, operand, invOrder));
5364db65e27SLei Zhang       }
5374db65e27SLei Zhang     }
5384db65e27SLei Zhang 
5394b2ba5a6SAndrzej Warzyński     auto vectorType = srcType.clone(
5405550c821STres Popp         cast<VectorType>(op->getResultTypes()[0]).getElementType());
5414db65e27SLei Zhang     Operation *elementwiseOp =
5424db65e27SLei Zhang         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
5434db65e27SLei Zhang                         vectorType, op->getAttrs());
5441538bd51SHanhan Wang     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5454db65e27SLei Zhang         op, op->getResultTypes()[0], elementwiseOp->getResult(0),
5464db65e27SLei Zhang         transposeMaps.front());
5471538bd51SHanhan Wang     return success();
5481538bd51SHanhan Wang   }
5491538bd51SHanhan Wang };
5501538bd51SHanhan Wang 
55199ef9eebSMatthias Springer // Returns the values in `arrayAttr` as an integer vector.
5527a69a9d7SNicolas Vasilache static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
55399ef9eebSMatthias Springer   return llvm::to_vector<4>(
55499ef9eebSMatthias Springer       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
55599ef9eebSMatthias Springer                       [](IntegerAttr attr) { return attr.getInt(); }));
55699ef9eebSMatthias Springer }
55799ef9eebSMatthias Springer 
55899ef9eebSMatthias Springer // Shuffles vector.bitcast op after vector.extract op.
55999ef9eebSMatthias Springer //
56099ef9eebSMatthias Springer // This transforms IR like:
56199ef9eebSMatthias Springer //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
5629816edc9SCullen Rhodes //   %1 = vector.extract %0[3] : f16 from vector<8xf16>
56399ef9eebSMatthias Springer // Into:
5649816edc9SCullen Rhodes //   %0 = vector.extract %src[1] : f32 from vector<4xf32>
56599ef9eebSMatthias Springer //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
5669816edc9SCullen Rhodes //   %2 = vector.extract %1[1] : f16 from vector<2xf16>
56799ef9eebSMatthias Springer struct BubbleDownVectorBitCastForExtract
56899ef9eebSMatthias Springer     : public OpRewritePattern<vector::ExtractOp> {
56999ef9eebSMatthias Springer   using OpRewritePattern::OpRewritePattern;
57099ef9eebSMatthias Springer 
57199ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
57299ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
57399ef9eebSMatthias Springer     // Only support extracting scalars for now.
574a1aad28dSLei Zhang     if (extractOp.getSourceVectorType().getRank() != 1)
57599ef9eebSMatthias Springer       return failure();
57699ef9eebSMatthias Springer 
5777c38fd60SJacques Pienaar     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
57899ef9eebSMatthias Springer     if (!castOp)
57999ef9eebSMatthias Springer       return failure();
58099ef9eebSMatthias Springer 
58199ef9eebSMatthias Springer     VectorType castSrcType = castOp.getSourceVectorType();
58299ef9eebSMatthias Springer     VectorType castDstType = castOp.getResultVectorType();
58399ef9eebSMatthias Springer     assert(castSrcType.getRank() == castDstType.getRank());
58499ef9eebSMatthias Springer 
58599ef9eebSMatthias Springer     // Fail to match if we only have one element in the cast op source.
58699ef9eebSMatthias Springer     // This is to avoid infinite loop given that this pattern can generate
58799ef9eebSMatthias Springer     // such cases.
58899ef9eebSMatthias Springer     if (castSrcType.getNumElements() == 1)
58999ef9eebSMatthias Springer       return failure();
59099ef9eebSMatthias Springer 
59199ef9eebSMatthias Springer     // Only support casting to a larger number of elements or now.
59299ef9eebSMatthias Springer     // E.g., vector<4xf32> -> vector<8xf16>.
59399ef9eebSMatthias Springer     if (castSrcType.getNumElements() > castDstType.getNumElements())
59499ef9eebSMatthias Springer       return failure();
59599ef9eebSMatthias Springer 
59699ef9eebSMatthias Springer     unsigned expandRatio =
59799ef9eebSMatthias Springer         castDstType.getNumElements() / castSrcType.getNumElements();
59899ef9eebSMatthias Springer 
5996626ed6fSlialan     // Get the first element of the mixed position as integer.
6006626ed6fSlialan     auto mixedPos = extractOp.getMixedPosition();
6016e41483bSKazu Hirata     if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
6026626ed6fSlialan       return failure();
6036e41483bSKazu Hirata     uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
60499ef9eebSMatthias Springer 
60599ef9eebSMatthias Springer     // Get the single scalar (as a vector) in the source value that packs the
60699ef9eebSMatthias Springer     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
60798f6289aSDiego Caballero     Location loc = extractOp.getLoc();
60899ef9eebSMatthias Springer     Value packedValue = rewriter.create<vector::ExtractOp>(
60998f6289aSDiego Caballero         loc, castOp.getSource(), index / expandRatio);
61098f6289aSDiego Caballero     Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
61198f6289aSDiego Caballero     Value zero = rewriter.create<arith::ConstantOp>(
61298f6289aSDiego Caballero         loc, packedVecType, rewriter.getZeroAttr(packedVecType));
61398f6289aSDiego Caballero     packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
61498f6289aSDiego Caballero                                                     /*position=*/0);
61599ef9eebSMatthias Springer 
61699ef9eebSMatthias Springer     // Cast it to a vector with the desired scalar's type.
61799ef9eebSMatthias Springer     // E.g. f32 -> vector<2xf16>
61899ef9eebSMatthias Springer     VectorType packedType =
61999ef9eebSMatthias Springer         VectorType::get({expandRatio}, castDstType.getElementType());
62098f6289aSDiego Caballero     Value castedValue =
62198f6289aSDiego Caballero         rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
62299ef9eebSMatthias Springer 
62399ef9eebSMatthias Springer     // Finally extract the desired scalar.
62498f6289aSDiego Caballero     rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
62598f6289aSDiego Caballero                                                    index % expandRatio);
62699ef9eebSMatthias Springer     return success();
62799ef9eebSMatthias Springer   }
62899ef9eebSMatthias Springer };
62999ef9eebSMatthias Springer 
63099ef9eebSMatthias Springer // Shuffles vector.bitcast op after vector.extract_strided_slice op.
63199ef9eebSMatthias Springer //
63299ef9eebSMatthias Springer // This transforms IR like:
63399ef9eebSMatthias Springer //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
63499ef9eebSMatthias Springer //     %0 = vector.extract_strided_slice %cast {
63599ef9eebSMatthias Springer //            offsets = [4], sizes = [4], strides = [1]
63699ef9eebSMatthias Springer //          } : vector<8xf16> to vector<4xf16>
63799ef9eebSMatthias Springer // Into:
63899ef9eebSMatthias Springer //   %0 = vector.extract_strided_slice %src {
63999ef9eebSMatthias Springer //          offsets = [2], sizes = [2], strides = [1]
64099ef9eebSMatthias Springer //        } : vector<4xf32> to vector<2xf32>
64199ef9eebSMatthias Springer //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
64299ef9eebSMatthias Springer struct BubbleDownBitCastForStridedSliceExtract
64399ef9eebSMatthias Springer     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
64499ef9eebSMatthias Springer   using OpRewritePattern::OpRewritePattern;
64599ef9eebSMatthias Springer 
64699ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
64799ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
6487c38fd60SJacques Pienaar     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
64999ef9eebSMatthias Springer     if (!castOp)
65099ef9eebSMatthias Springer       return failure();
65199ef9eebSMatthias Springer 
65299ef9eebSMatthias Springer     VectorType castSrcType = castOp.getSourceVectorType();
65399ef9eebSMatthias Springer     VectorType castDstType = castOp.getResultVectorType();
65499ef9eebSMatthias Springer     assert(castSrcType.getRank() == castDstType.getRank());
65599ef9eebSMatthias Springer 
65699ef9eebSMatthias Springer     int64_t castSrcLastDim = castSrcType.getShape().back();
65799ef9eebSMatthias Springer     int64_t castDstLastDim = castDstType.getShape().back();
65899ef9eebSMatthias Springer     // Require casting to more elements for now; other cases to be implemented.
65999ef9eebSMatthias Springer     if (castSrcLastDim > castDstLastDim)
66099ef9eebSMatthias Springer       return failure();
66199ef9eebSMatthias Springer 
66299ef9eebSMatthias Springer     // Only accept all one strides for now.
6637c38fd60SJacques Pienaar     if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
6649e5d2495SKazu Hirata                      [](const APInt &val) { return !val.isOne(); }))
66599ef9eebSMatthias Springer       return failure();
66699ef9eebSMatthias Springer 
667a1aad28dSLei Zhang     unsigned rank = extractOp.getSourceVectorType().getRank();
66899ef9eebSMatthias Springer     assert(castDstLastDim % castSrcLastDim == 0);
66999ef9eebSMatthias Springer     int64_t expandRatio = castDstLastDim / castSrcLastDim;
67099ef9eebSMatthias Springer 
67199ef9eebSMatthias Springer     // If we have a less number of offsets than the rank, then implicitly we
67299ef9eebSMatthias Springer     // are selecting the full range for the last bitcasted dimension; other
67399ef9eebSMatthias Springer     // dimensions aren't affected. Otherwise, we need to scale down the last
67499ef9eebSMatthias Springer     // dimension's offset given we are extracting from less elements now.
6757c38fd60SJacques Pienaar     ArrayAttr newOffsets = extractOp.getOffsets();
67699ef9eebSMatthias Springer     if (newOffsets.size() == rank) {
6777a69a9d7SNicolas Vasilache       SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
67899ef9eebSMatthias Springer       if (offsets.back() % expandRatio != 0)
67999ef9eebSMatthias Springer         return failure();
68099ef9eebSMatthias Springer       offsets.back() = offsets.back() / expandRatio;
68199ef9eebSMatthias Springer       newOffsets = rewriter.getI64ArrayAttr(offsets);
68299ef9eebSMatthias Springer     }
68399ef9eebSMatthias Springer 
68499ef9eebSMatthias Springer     // Similarly for sizes.
6857c38fd60SJacques Pienaar     ArrayAttr newSizes = extractOp.getSizes();
68699ef9eebSMatthias Springer     if (newSizes.size() == rank) {
6877a69a9d7SNicolas Vasilache       SmallVector<int64_t> sizes = getIntValueVector(newSizes);
68899ef9eebSMatthias Springer       if (sizes.back() % expandRatio != 0)
68999ef9eebSMatthias Springer         return failure();
69099ef9eebSMatthias Springer       sizes.back() = sizes.back() / expandRatio;
69199ef9eebSMatthias Springer       newSizes = rewriter.getI64ArrayAttr(sizes);
69299ef9eebSMatthias Springer     }
69399ef9eebSMatthias Springer 
6947a69a9d7SNicolas Vasilache     SmallVector<int64_t> dims =
6955550c821STres Popp         llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
69699ef9eebSMatthias Springer     dims.back() = dims.back() / expandRatio;
69799ef9eebSMatthias Springer     VectorType newExtractType =
69899ef9eebSMatthias Springer         VectorType::get(dims, castSrcType.getElementType());
69999ef9eebSMatthias Springer 
70099ef9eebSMatthias Springer     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
7017c38fd60SJacques Pienaar         extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
7027c38fd60SJacques Pienaar         newSizes, extractOp.getStrides());
70399ef9eebSMatthias Springer 
70499ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
70599ef9eebSMatthias Springer         extractOp, extractOp.getType(), newExtractOp);
70699ef9eebSMatthias Springer 
70799ef9eebSMatthias Springer     return success();
70899ef9eebSMatthias Springer   }
70999ef9eebSMatthias Springer };
71099ef9eebSMatthias Springer 
71199ef9eebSMatthias Springer // Shuffles vector.bitcast op before vector.insert_strided_slice op.
71299ef9eebSMatthias Springer //
71399ef9eebSMatthias Springer // This transforms IR like:
7144623c114SDiego Caballero //   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
7154623c114SDiego Caballero //   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
7164623c114SDiego Caballero // Into:
7174623c114SDiego Caballero //   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
7184623c114SDiego Caballero //   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
7194623c114SDiego Caballero //   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
7204623c114SDiego Caballero //
7214623c114SDiego Caballero struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
7224623c114SDiego Caballero   using OpRewritePattern::OpRewritePattern;
7234623c114SDiego Caballero 
7244623c114SDiego Caballero   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
7254623c114SDiego Caballero                                 PatternRewriter &rewriter) const override {
7264623c114SDiego Caballero     VectorType castSrcType = bitcastOp.getSourceVectorType();
7274623c114SDiego Caballero     VectorType castDstType = bitcastOp.getResultVectorType();
7284623c114SDiego Caballero 
7294623c114SDiego Caballero     // 0-D and scalable vectors are not supported yet.
7304623c114SDiego Caballero     if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
7314623c114SDiego Caballero         castDstType.isScalable())
7324623c114SDiego Caballero       return failure();
7334623c114SDiego Caballero 
7344623c114SDiego Caballero     int64_t castSrcLastDim = castSrcType.getShape().back();
7354623c114SDiego Caballero     int64_t castDstLastDim = castDstType.getShape().back();
7364623c114SDiego Caballero     bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
7374623c114SDiego Caballero     int64_t ratio;
7384623c114SDiego Caballero     if (isNumElemsShrink) {
7394623c114SDiego Caballero       assert(castSrcLastDim % castDstLastDim == 0);
7404623c114SDiego Caballero       ratio = castSrcLastDim / castDstLastDim;
7414623c114SDiego Caballero     } else {
7424623c114SDiego Caballero       assert(castDstLastDim % castSrcLastDim == 0);
7434623c114SDiego Caballero       ratio = castDstLastDim / castSrcLastDim;
7444623c114SDiego Caballero     }
7454623c114SDiego Caballero 
7464623c114SDiego Caballero     auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
7474623c114SDiego Caballero     if (!insertOp)
7484623c114SDiego Caballero       return failure();
7494623c114SDiego Caballero 
7504623c114SDiego Caballero     // Only vector sources are supported for now.
7514623c114SDiego Caballero     auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
7524623c114SDiego Caballero     if (!insertSrcType)
7534623c114SDiego Caballero       return failure();
7544623c114SDiego Caballero 
7554623c114SDiego Caballero     // Bitcast the source.
7564623c114SDiego Caballero     SmallVector<int64_t> srcDims(insertSrcType.getShape());
7574623c114SDiego Caballero     srcDims.back() =
7584623c114SDiego Caballero         isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
7594623c114SDiego Caballero     VectorType newCastSrcType =
7604623c114SDiego Caballero         VectorType::get(srcDims, castDstType.getElementType());
7614623c114SDiego Caballero     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
7624623c114SDiego Caballero         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
7634623c114SDiego Caballero 
7644623c114SDiego Caballero     SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
7654623c114SDiego Caballero     dstDims.back() =
7664623c114SDiego Caballero         isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
7674623c114SDiego Caballero     VectorType newCastDstType =
7684623c114SDiego Caballero         VectorType::get(dstDims, castDstType.getElementType());
7694623c114SDiego Caballero 
7704623c114SDiego Caballero     // Bitcast the destination.
7714623c114SDiego Caballero     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
7724623c114SDiego Caballero         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
7734623c114SDiego Caballero 
7744623c114SDiego Caballero     // Generate new insert.
7754623c114SDiego Caballero     rewriter.replaceOpWithNewOp<vector::InsertOp>(
7764623c114SDiego Caballero         bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
7774623c114SDiego Caballero     return success();
7784623c114SDiego Caballero   }
7794623c114SDiego Caballero };
7804623c114SDiego Caballero 
7814623c114SDiego Caballero // Shuffles vector.bitcast op before vector.insert_strided_slice op.
7824623c114SDiego Caballero //
7834623c114SDiego Caballero // This transforms IR like:
78499ef9eebSMatthias Springer //   %0 = vector.insert_strided_slice %src, %dst {
78599ef9eebSMatthias Springer //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
78699ef9eebSMatthias Springer //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
78799ef9eebSMatthias Springer // Into:
78899ef9eebSMatthias Springer //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
78999ef9eebSMatthias Springer //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
79099ef9eebSMatthias Springer //   %2 = vector.insert_strided_slice %src, %dst {
79199ef9eebSMatthias Springer //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
79299ef9eebSMatthias Springer struct BubbleUpBitCastForStridedSliceInsert
79399ef9eebSMatthias Springer     : public OpRewritePattern<vector::BitCastOp> {
79499ef9eebSMatthias Springer   using OpRewritePattern::OpRewritePattern;
79527cc31b6SNicolas Vasilache 
79699ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
79799ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
79899ef9eebSMatthias Springer     VectorType castSrcType = bitcastOp.getSourceVectorType();
79999ef9eebSMatthias Springer     VectorType castDstType = bitcastOp.getResultVectorType();
80099ef9eebSMatthias Springer     assert(castSrcType.getRank() == castDstType.getRank());
80128f9bfe4SXiang Li     // Skip 0-D vector which will not from InsertStridedSliceOp.
80228f9bfe4SXiang Li     if (castSrcType.getRank() == 0)
80328f9bfe4SXiang Li       return failure();
80499ef9eebSMatthias Springer 
80599ef9eebSMatthias Springer     int64_t castSrcLastDim = castSrcType.getShape().back();
80699ef9eebSMatthias Springer     int64_t castDstLastDim = castDstType.getShape().back();
80799ef9eebSMatthias Springer     // Require casting to less elements for now; other cases to be implemented.
80899ef9eebSMatthias Springer     if (castSrcLastDim < castDstLastDim)
80999ef9eebSMatthias Springer       return failure();
81099ef9eebSMatthias Springer 
81199ef9eebSMatthias Springer     assert(castSrcLastDim % castDstLastDim == 0);
81299ef9eebSMatthias Springer     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
81399ef9eebSMatthias Springer 
81499ef9eebSMatthias Springer     auto insertOp =
8157c38fd60SJacques Pienaar         bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
81699ef9eebSMatthias Springer     if (!insertOp)
81799ef9eebSMatthias Springer       return failure();
81899ef9eebSMatthias Springer 
81999ef9eebSMatthias Springer     // Only accept all one strides for now.
8207c38fd60SJacques Pienaar     if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
8219e5d2495SKazu Hirata                      [](const APInt &val) { return !val.isOne(); }))
82299ef9eebSMatthias Springer       return failure();
82399ef9eebSMatthias Springer 
82499ef9eebSMatthias Springer     unsigned rank = insertOp.getSourceVectorType().getRank();
82599ef9eebSMatthias Springer     // Require insert op to have the same rank for the source and destination
82699ef9eebSMatthias Springer     // vector; other cases to be implemented.
82799ef9eebSMatthias Springer     if (rank != insertOp.getDestVectorType().getRank())
82899ef9eebSMatthias Springer       return failure();
82999ef9eebSMatthias Springer 
830dc26c030Sstanley-nod     // Requires that shape of insert op src is castable to dstType.
831dc26c030Sstanley-nod     unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
832dc26c030Sstanley-nod     unsigned destinationWidth =
833dc26c030Sstanley-nod         castDstType.getElementType().getIntOrFloatBitWidth();
834dc26c030Sstanley-nod     unsigned numElements = destinationWidth / sourceWidth;
835dc26c030Sstanley-nod     if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
836dc26c030Sstanley-nod       return failure();
837dc26c030Sstanley-nod 
8387c38fd60SJacques Pienaar     ArrayAttr newOffsets = insertOp.getOffsets();
83999ef9eebSMatthias Springer     assert(newOffsets.size() == rank);
8407a69a9d7SNicolas Vasilache     SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
84199ef9eebSMatthias Springer     if (offsets.back() % shrinkRatio != 0)
84299ef9eebSMatthias Springer       return failure();
84399ef9eebSMatthias Springer     offsets.back() = offsets.back() / shrinkRatio;
84499ef9eebSMatthias Springer     newOffsets = rewriter.getI64ArrayAttr(offsets);
84599ef9eebSMatthias Springer 
8467a69a9d7SNicolas Vasilache     SmallVector<int64_t> srcDims =
84799ef9eebSMatthias Springer         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
84899ef9eebSMatthias Springer     srcDims.back() = srcDims.back() / shrinkRatio;
84999ef9eebSMatthias Springer     VectorType newCastSrcType =
85099ef9eebSMatthias Springer         VectorType::get(srcDims, castDstType.getElementType());
85199ef9eebSMatthias Springer 
85299ef9eebSMatthias Springer     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
8537c38fd60SJacques Pienaar         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
85499ef9eebSMatthias Springer 
8557a69a9d7SNicolas Vasilache     SmallVector<int64_t> dstDims =
85699ef9eebSMatthias Springer         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
85799ef9eebSMatthias Springer     dstDims.back() = dstDims.back() / shrinkRatio;
85899ef9eebSMatthias Springer     VectorType newCastDstType =
85999ef9eebSMatthias Springer         VectorType::get(dstDims, castDstType.getElementType());
86099ef9eebSMatthias Springer 
86199ef9eebSMatthias Springer     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
8627c38fd60SJacques Pienaar         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
86399ef9eebSMatthias Springer 
86499ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
86599ef9eebSMatthias Springer         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
8667c38fd60SJacques Pienaar         insertOp.getStrides());
86799ef9eebSMatthias Springer 
86899ef9eebSMatthias Springer     return success();
86999ef9eebSMatthias Springer   }
87099ef9eebSMatthias Springer };
87199ef9eebSMatthias Springer 
872650f04feSQuinn Dawkins // Breaks down vector.bitcast op
873650f04feSQuinn Dawkins //
874650f04feSQuinn Dawkins // This transforms IR like:
875650f04feSQuinn Dawkins //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
876650f04feSQuinn Dawkins // Into:
877650f04feSQuinn Dawkins //   %cst = vector.splat %c0_f32 : vector<4xf32>
878650f04feSQuinn Dawkins //   %1 = vector.extract_strided_slice %0 {
879650f04feSQuinn Dawkins //          offsets = [0], sizes = [4], strides = [1]
880650f04feSQuinn Dawkins //        } : vector<8xf16> to vector<4xf16>
881650f04feSQuinn Dawkins //   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
882650f04feSQuinn Dawkins //   %4 = vector.insert_strided_slice %2, %cst {
883650f04feSQuinn Dawkins //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
884650f04feSQuinn Dawkins //   %5 = vector.extract_strided_slice %0 {
885650f04feSQuinn Dawkins //          offsets = [4], sizes = [4], strides = [1]
886650f04feSQuinn Dawkins //        } : vector<8xf16> to vector<4xf16>
887650f04feSQuinn Dawkins //   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
888650f04feSQuinn Dawkins //   %7 = vector.insert_strided_slice %6, %cst {
889650f04feSQuinn Dawkins //          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
890650f04feSQuinn Dawkins struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
891650f04feSQuinn Dawkins   using OpRewritePattern::OpRewritePattern;
892650f04feSQuinn Dawkins 
893650f04feSQuinn Dawkins public:
894650f04feSQuinn Dawkins   BreakDownVectorBitCast(MLIRContext *context,
895650f04feSQuinn Dawkins                          std::function<bool(vector::BitCastOp)> controlFn,
896650f04feSQuinn Dawkins                          PatternBenefit benefit)
897650f04feSQuinn Dawkins       : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
898650f04feSQuinn Dawkins 
899650f04feSQuinn Dawkins   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
900650f04feSQuinn Dawkins                                 PatternRewriter &rewriter) const override {
901650f04feSQuinn Dawkins 
902650f04feSQuinn Dawkins     if (controlFn && !controlFn(bitcastOp))
903650f04feSQuinn Dawkins       return failure();
904650f04feSQuinn Dawkins 
905650f04feSQuinn Dawkins     VectorType castSrcType = bitcastOp.getSourceVectorType();
906650f04feSQuinn Dawkins     VectorType castDstType = bitcastOp.getResultVectorType();
907650f04feSQuinn Dawkins     assert(castSrcType.getRank() == castDstType.getRank());
908650f04feSQuinn Dawkins 
909d88293d8SAndrzej Warzyński     // This transformation builds on top of
910d88293d8SAndrzej Warzyński     // vector.{extract|insert}_strided_slice, which do not support
911d88293d8SAndrzej Warzyński     // extracting/inserting "scallable sub-vectors". Bail out.
912d88293d8SAndrzej Warzyński     if (castSrcType.isScalable())
913d88293d8SAndrzej Warzyński       return rewriter.notifyMatchFailure(bitcastOp,
914d88293d8SAndrzej Warzyński                                          "Scalable vectors are not supported");
915d88293d8SAndrzej Warzyński 
916650f04feSQuinn Dawkins     // Only support rank 1 case for now.
917650f04feSQuinn Dawkins     if (castSrcType.getRank() != 1)
918650f04feSQuinn Dawkins       return failure();
919650f04feSQuinn Dawkins 
920650f04feSQuinn Dawkins     int64_t castSrcLastDim = castSrcType.getShape().back();
921650f04feSQuinn Dawkins     int64_t castDstLastDim = castDstType.getShape().back();
922650f04feSQuinn Dawkins     // Require casting to less elements for now; other cases to be implemented.
923650f04feSQuinn Dawkins     if (castSrcLastDim < castDstLastDim)
924650f04feSQuinn Dawkins       return failure();
925650f04feSQuinn Dawkins 
926650f04feSQuinn Dawkins     assert(castSrcLastDim % castDstLastDim == 0);
927650f04feSQuinn Dawkins     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
928650f04feSQuinn Dawkins     // Nothing to do if it is already bitcasting to a single element.
929650f04feSQuinn Dawkins     if (castSrcLastDim == shrinkRatio)
930650f04feSQuinn Dawkins       return failure();
931650f04feSQuinn Dawkins 
932650f04feSQuinn Dawkins     Location loc = bitcastOp.getLoc();
933650f04feSQuinn Dawkins     Type elemType = castDstType.getElementType();
934650f04feSQuinn Dawkins     assert(elemType.isSignlessIntOrIndexOrFloat());
935650f04feSQuinn Dawkins 
936650f04feSQuinn Dawkins     Value zero = rewriter.create<arith::ConstantOp>(
937650f04feSQuinn Dawkins         loc, elemType, rewriter.getZeroAttr(elemType));
938650f04feSQuinn Dawkins     Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
939650f04feSQuinn Dawkins 
9409cbc1f29SHan-Chung Wang     SmallVector<int64_t> sliceShape = {castDstLastDim};
9419cbc1f29SHan-Chung Wang     SmallVector<int64_t> strides = {1};
942650f04feSQuinn Dawkins     VectorType newCastDstType =
943650f04feSQuinn Dawkins         VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
944650f04feSQuinn Dawkins                         castDstType.getElementType());
945650f04feSQuinn Dawkins 
946650f04feSQuinn Dawkins     for (int i = 0, e = shrinkRatio; i < e; ++i) {
947650f04feSQuinn Dawkins       Value extracted = rewriter.create<ExtractStridedSliceOp>(
948650f04feSQuinn Dawkins           loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
949650f04feSQuinn Dawkins           sliceShape, strides);
950650f04feSQuinn Dawkins       Value bitcast =
951650f04feSQuinn Dawkins           rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
952650f04feSQuinn Dawkins       res = rewriter.create<InsertStridedSliceOp>(
953650f04feSQuinn Dawkins           loc, bitcast, res,
954650f04feSQuinn Dawkins           ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
955650f04feSQuinn Dawkins     }
956650f04feSQuinn Dawkins     rewriter.replaceOp(bitcastOp, res);
957650f04feSQuinn Dawkins     return success();
958650f04feSQuinn Dawkins   }
959650f04feSQuinn Dawkins 
960650f04feSQuinn Dawkins private:
961650f04feSQuinn Dawkins   std::function<bool(BitCastOp)> controlFn;
962650f04feSQuinn Dawkins };
963650f04feSQuinn Dawkins 
96459fbba94SAndrzej Warzyński /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
9654d339ec9SAndrzej Warzynski /// ```
9664d339ec9SAndrzej Warzynski /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
9674d339ec9SAndrzej Warzynski /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
9684d339ec9SAndrzej Warzynski /// %r = arith.addi %a, %b : vector<1x4xindex>
9694d339ec9SAndrzej Warzynski /// ```
9704d339ec9SAndrzej Warzynski /// Gets converted to:
9714d339ec9SAndrzej Warzynski /// ```
9724d339ec9SAndrzej Warzynski /// %r = arith.addi %arg0, %arg1 : index
9734d339ec9SAndrzej Warzynski /// %b = vector.broadcast %r : index to vector<1x4xindex>
9744d339ec9SAndrzej Warzynski /// ```
97559fbba94SAndrzej Warzyński ///
97659fbba94SAndrzej Warzyński /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
97759fbba94SAndrzej Warzyński /// ops.
9784d339ec9SAndrzej Warzynski struct ReorderElementwiseOpsOnBroadcast final
9794d339ec9SAndrzej Warzynski     : public OpTraitRewritePattern<OpTrait::Elementwise> {
9804d339ec9SAndrzej Warzynski   using OpTraitRewritePattern::OpTraitRewritePattern;
9814d339ec9SAndrzej Warzynski   LogicalResult matchAndRewrite(Operation *op,
9824d339ec9SAndrzej Warzynski                                 PatternRewriter &rewriter) const override {
9834d339ec9SAndrzej Warzynski     if (op->getNumResults() != 1)
9844d339ec9SAndrzej Warzynski       return failure();
9854d339ec9SAndrzej Warzynski     if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
9864d339ec9SAndrzej Warzynski       return failure();
9874d339ec9SAndrzej Warzynski     if (!OpTrait::hasElementwiseMappableTraits(op))
988efe3db21SAndrzej Warzyński       return rewriter.notifyMatchFailure(
989efe3db21SAndrzej Warzyński           op, "Op doesn't have ElementwiseMappableTraits");
990efe3db21SAndrzej Warzyński     if (op->getNumOperands() == 0)
9914d339ec9SAndrzej Warzynski       return failure();
992efe3db21SAndrzej Warzyński     if (op->getResults()[0].getType() != op->getOperand(0).getType())
993efe3db21SAndrzej Warzyński       return rewriter.notifyMatchFailure(op,
994efe3db21SAndrzej Warzyński                                          "result and operand type mismatch");
995f28f09dcSMaheshRavishankar     if (isa<vector::FMAOp>(op)) {
996efe3db21SAndrzej Warzyński       return rewriter.notifyMatchFailure(
997efe3db21SAndrzej Warzyński           op,
998efe3db21SAndrzej Warzyński           "Op only accepts vector types - not supported as broadcast source "
999efe3db21SAndrzej Warzyński           "might be a scalar");
1000f28f09dcSMaheshRavishankar     }
10014d339ec9SAndrzej Warzynski 
100259fbba94SAndrzej Warzyński     // Get the type of the lhs operand
100359fbba94SAndrzej Warzyński     auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
100459fbba94SAndrzej Warzyński     if (!lhsBcastOrSplat ||
100559fbba94SAndrzej Warzyński         !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
10064d339ec9SAndrzej Warzynski       return failure();
100759fbba94SAndrzej Warzyński     auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
10084d339ec9SAndrzej Warzynski 
100959fbba94SAndrzej Warzyński     // Make sure that all operands are broadcast from identical types:
101059fbba94SAndrzej Warzyński     //  * scalar (`vector.broadcast` + `vector.splat`), or
101159fbba94SAndrzej Warzyński     //  * vector (`vector.broadcast`).
101259fbba94SAndrzej Warzyński     // Otherwise the re-ordering wouldn't be safe.
101359fbba94SAndrzej Warzyński     if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
10144d339ec9SAndrzej Warzynski           auto bcast = val.getDefiningOp<vector::BroadcastOp>();
101559fbba94SAndrzej Warzyński           if (bcast)
101659fbba94SAndrzej Warzyński             return (bcast.getOperand().getType() == lhsBcastOrSplatType);
101759fbba94SAndrzej Warzyński           auto splat = val.getDefiningOp<vector::SplatOp>();
101859fbba94SAndrzej Warzyński           if (splat)
101959fbba94SAndrzej Warzyński             return (splat.getOperand().getType() == lhsBcastOrSplatType);
102059fbba94SAndrzej Warzyński           return false;
10214d339ec9SAndrzej Warzynski         })) {
10224d339ec9SAndrzej Warzynski       return failure();
10234d339ec9SAndrzej Warzynski     }
10244d339ec9SAndrzej Warzynski 
102559fbba94SAndrzej Warzyński     // Collect the source values before broadcasting
10264d339ec9SAndrzej Warzynski     SmallVector<Value> srcValues;
10274d339ec9SAndrzej Warzynski     srcValues.reserve(op->getNumOperands());
10284d339ec9SAndrzej Warzynski     for (Value operand : op->getOperands()) {
102959fbba94SAndrzej Warzyński       srcValues.push_back(operand.getDefiningOp()->getOperand(0));
10304d339ec9SAndrzej Warzynski     }
10314d339ec9SAndrzej Warzynski 
103259fbba94SAndrzej Warzyński     // Create the "elementwise" Op
10334d339ec9SAndrzej Warzynski     Operation *elementwiseOp =
10344d339ec9SAndrzej Warzynski         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
103559fbba94SAndrzej Warzyński                         lhsBcastOrSplatType, op->getAttrs());
10364d339ec9SAndrzej Warzynski 
103759fbba94SAndrzej Warzyński     // Replace the original Op with the elementwise Op
10384d339ec9SAndrzej Warzynski     auto vectorType = op->getResultTypes()[0];
10394d339ec9SAndrzej Warzynski     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
10404d339ec9SAndrzej Warzynski         op, vectorType, elementwiseOp->getResults());
10414d339ec9SAndrzej Warzynski 
10424d339ec9SAndrzej Warzynski     return success();
10434d339ec9SAndrzej Warzynski   }
10444d339ec9SAndrzej Warzynski };
10454d339ec9SAndrzej Warzynski 
104699ef9eebSMatthias Springer // Helper that returns a vector comparison that constructs a mask:
104799ef9eebSMatthias Springer //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
104899ef9eebSMatthias Springer //
104999ef9eebSMatthias Springer // If `dim == 0` then the result will be a 0-D vector.
105099ef9eebSMatthias Springer //
105199ef9eebSMatthias Springer // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
105299ef9eebSMatthias Springer //       much more compact, IR for this operation, but LLVM eventually
105399ef9eebSMatthias Springer //       generates more elaborate instructions for this intrinsic since it
105499ef9eebSMatthias Springer //       is very conservative on the boundary conditions.
105599ef9eebSMatthias Springer static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
10567bc8ad51SJavier Setoain                                    bool force32BitVectorIndices, int64_t dim,
105799ef9eebSMatthias Springer                                    Value b, Value *off = nullptr) {
105899ef9eebSMatthias Springer   auto loc = op->getLoc();
105999ef9eebSMatthias Springer   // If we can assume all indices fit in 32-bit, we perform the vector
106099ef9eebSMatthias Springer   // comparison in 32-bit to get a higher degree of SIMD parallelism.
106199ef9eebSMatthias Springer   // Otherwise we perform the vector comparison using 64-bit indices.
106299ef9eebSMatthias Springer   Type idxType =
10637bc8ad51SJavier Setoain       force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
106499ef9eebSMatthias Springer   DenseIntElementsAttr indicesAttr;
10657bc8ad51SJavier Setoain   if (dim == 0 && force32BitVectorIndices) {
106699ef9eebSMatthias Springer     indicesAttr = DenseIntElementsAttr::get(
106799ef9eebSMatthias Springer         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
106899ef9eebSMatthias Springer   } else if (dim == 0) {
106999ef9eebSMatthias Springer     indicesAttr = DenseIntElementsAttr::get(
107099ef9eebSMatthias Springer         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
10717bc8ad51SJavier Setoain   } else if (force32BitVectorIndices) {
107299ef9eebSMatthias Springer     indicesAttr = rewriter.getI32VectorAttr(
107399ef9eebSMatthias Springer         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
107499ef9eebSMatthias Springer   } else {
107599ef9eebSMatthias Springer     indicesAttr = rewriter.getI64VectorAttr(
107699ef9eebSMatthias Springer         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
107799ef9eebSMatthias Springer   }
107899ef9eebSMatthias Springer   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
107999ef9eebSMatthias Springer   // Add in an offset if requested.
108099ef9eebSMatthias Springer   if (off) {
1081a75a46dbSJavier Setoain     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
10826a8ba318SRiver Riddle     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
108399ef9eebSMatthias Springer     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
108499ef9eebSMatthias Springer   }
108599ef9eebSMatthias Springer   // Construct the vector comparison.
1086a75a46dbSJavier Setoain   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
10876a8ba318SRiver Riddle   Value bounds =
10886a8ba318SRiver Riddle       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
108999ef9eebSMatthias Springer   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
109099ef9eebSMatthias Springer                                         bounds);
109199ef9eebSMatthias Springer }
109299ef9eebSMatthias Springer 
109399ef9eebSMatthias Springer template <typename ConcreteOp>
109499ef9eebSMatthias Springer struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
109599ef9eebSMatthias Springer public:
109627cc31b6SNicolas Vasilache   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
109727cc31b6SNicolas Vasilache                                    PatternBenefit benefit = 1)
109827cc31b6SNicolas Vasilache       : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
10997bc8ad51SJavier Setoain         force32BitVectorIndices(enableIndexOpt) {}
110099ef9eebSMatthias Springer 
110199ef9eebSMatthias Springer   LogicalResult matchAndRewrite(ConcreteOp xferOp,
110299ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
110399ef9eebSMatthias Springer     if (!xferOp.hasOutOfBoundsDim())
110499ef9eebSMatthias Springer       return failure();
110599ef9eebSMatthias Springer 
1106be650de5SKazu Hirata     if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
110799ef9eebSMatthias Springer       return failure();
110899ef9eebSMatthias Springer 
110999ef9eebSMatthias Springer     Location loc = xferOp->getLoc();
111099ef9eebSMatthias Springer     VectorType vtp = xferOp.getVectorType();
111199ef9eebSMatthias Springer 
1112f2b89c7aSJavier Setoain     // Create the in-bounds mask with all elements between [0 .. dim - offset)
1113f2b89c7aSJavier Setoain     // set and [dim - offset .. vector_length) unset.
111499ef9eebSMatthias Springer     //
111599ef9eebSMatthias Springer     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
111699ef9eebSMatthias Springer     //       dimensions here.
11177c38fd60SJacques Pienaar     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
11187c38fd60SJacques Pienaar     Value off = xferOp.getIndices()[lastIndex];
111999ef9eebSMatthias Springer     Value dim =
11207c38fd60SJacques Pienaar         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
1121f2b89c7aSJavier Setoain     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
1122f2b89c7aSJavier Setoain     Value mask = rewriter.create<vector::CreateMaskOp>(
1123f2b89c7aSJavier Setoain         loc,
1124f2b89c7aSJavier Setoain         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1125f22af204SAndrzej Warzynski                         vtp.getScalableDims()),
1126f2b89c7aSJavier Setoain         b);
11277c38fd60SJacques Pienaar     if (xferOp.getMask()) {
112899ef9eebSMatthias Springer       // Intersect the in-bounds with the mask specified as an op parameter.
11297c38fd60SJacques Pienaar       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
113099ef9eebSMatthias Springer     }
113199ef9eebSMatthias Springer 
11325fcf907bSMatthias Springer     rewriter.modifyOpInPlace(xferOp, [&]() {
11337c38fd60SJacques Pienaar       xferOp.getMaskMutable().assign(mask);
11347c38fd60SJacques Pienaar       xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
113599ef9eebSMatthias Springer     });
113699ef9eebSMatthias Springer 
113799ef9eebSMatthias Springer     return success();
113899ef9eebSMatthias Springer   }
113999ef9eebSMatthias Springer 
114099ef9eebSMatthias Springer private:
11417bc8ad51SJavier Setoain   const bool force32BitVectorIndices;
114299ef9eebSMatthias Springer };
114399ef9eebSMatthias Springer 
114499ef9eebSMatthias Springer /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
114599ef9eebSMatthias Springer class VectorCreateMaskOpConversion
114699ef9eebSMatthias Springer     : public OpRewritePattern<vector::CreateMaskOp> {
114799ef9eebSMatthias Springer public:
114899ef9eebSMatthias Springer   explicit VectorCreateMaskOpConversion(MLIRContext *context,
114927cc31b6SNicolas Vasilache                                         bool enableIndexOpt,
115027cc31b6SNicolas Vasilache                                         PatternBenefit benefit = 1)
115127cc31b6SNicolas Vasilache       : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
11527bc8ad51SJavier Setoain         force32BitVectorIndices(enableIndexOpt) {}
115399ef9eebSMatthias Springer 
115499ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
115599ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
115699ef9eebSMatthias Springer     auto dstType = op.getType();
11575550c821STres Popp     if (cast<VectorType>(dstType).isScalable())
1158a75a46dbSJavier Setoain       return failure();
115999ef9eebSMatthias Springer     int64_t rank = dstType.getRank();
116099ef9eebSMatthias Springer     if (rank > 1)
116199ef9eebSMatthias Springer       return failure();
116299ef9eebSMatthias Springer     rewriter.replaceOp(
11637bc8ad51SJavier Setoain         op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
116499ef9eebSMatthias Springer                                   rank == 0 ? 0 : dstType.getDimSize(0),
116599ef9eebSMatthias Springer                                   op.getOperand(0)));
116699ef9eebSMatthias Springer     return success();
116799ef9eebSMatthias Springer   }
116899ef9eebSMatthias Springer 
116999ef9eebSMatthias Springer private:
11707bc8ad51SJavier Setoain   const bool force32BitVectorIndices;
117199ef9eebSMatthias Springer };
117299ef9eebSMatthias Springer 
117315a08cf2SDiego Caballero /// Returns true if all the `i1` elements of `constantOp` are set to `value`.
117415a08cf2SDiego Caballero static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
117515a08cf2SDiego Caballero   auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
117615a08cf2SDiego Caballero   // TODO: Support non-dense constant.
117715a08cf2SDiego Caballero   if (!denseAttr)
117815a08cf2SDiego Caballero     return false;
117915a08cf2SDiego Caballero 
118015a08cf2SDiego Caballero   assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
118115a08cf2SDiego Caballero   return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
118215a08cf2SDiego Caballero }
118315a08cf2SDiego Caballero 
118415a08cf2SDiego Caballero /// Folds a select operation between an all-true and all-false vector. For now,
118515a08cf2SDiego Caballero /// only single element vectors (i.e., vector<1xi1>) are supported. That is:
118615a08cf2SDiego Caballero ///
118715a08cf2SDiego Caballero ///   %true = arith.constant dense<true> : vector<1xi1>
118815a08cf2SDiego Caballero ///   %false = arith.constant dense<false> : vector<1xi1>
118915a08cf2SDiego Caballero ///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
119015a08cf2SDiego Caballero ///   =>
119115a08cf2SDiego Caballero ///   %result = vector.broadcast %cond : i1 to vector<1xi1>
119215a08cf2SDiego Caballero ///
119315a08cf2SDiego Caballero /// InstCombine seems to handle vectors with multiple elements but not the
119415a08cf2SDiego Caballero /// single element ones.
119515a08cf2SDiego Caballero struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
119615a08cf2SDiego Caballero   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
119715a08cf2SDiego Caballero 
119815a08cf2SDiego Caballero   LogicalResult matchAndRewrite(arith::SelectOp selectOp,
119915a08cf2SDiego Caballero                                 PatternRewriter &rewriter) const override {
120015a08cf2SDiego Caballero     auto vecType = dyn_cast<VectorType>(selectOp.getType());
120115a08cf2SDiego Caballero     if (!vecType || !vecType.getElementType().isInteger(1))
120215a08cf2SDiego Caballero       return failure();
120315a08cf2SDiego Caballero 
120415a08cf2SDiego Caballero     // Only scalar conditions can be folded.
120515a08cf2SDiego Caballero     Value cond = selectOp.getCondition();
120615a08cf2SDiego Caballero     if (isa<VectorType>(cond.getType()))
120715a08cf2SDiego Caballero       return failure();
120815a08cf2SDiego Caballero 
120915a08cf2SDiego Caballero     // TODO: Support n-D and scalable vectors.
121015a08cf2SDiego Caballero     if (vecType.getRank() != 1 || vecType.isScalable())
121115a08cf2SDiego Caballero       return failure();
121215a08cf2SDiego Caballero 
121315a08cf2SDiego Caballero     // TODO: Support vectors with multiple elements.
121415a08cf2SDiego Caballero     if (vecType.getShape()[0] != 1)
121515a08cf2SDiego Caballero       return failure();
121615a08cf2SDiego Caballero 
121715a08cf2SDiego Caballero     auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
121815a08cf2SDiego Caballero     if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
121915a08cf2SDiego Caballero       return failure();
122015a08cf2SDiego Caballero 
122115a08cf2SDiego Caballero     auto falseConst =
122215a08cf2SDiego Caballero         selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
122315a08cf2SDiego Caballero     if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
122415a08cf2SDiego Caballero       return failure();
122515a08cf2SDiego Caballero 
122615a08cf2SDiego Caballero     // Replace select with its condition broadcasted to single element vector.
122715a08cf2SDiego Caballero     auto elemType = rewriter.getIntegerType(vecType.getNumElements());
122815a08cf2SDiego Caballero     auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
122915a08cf2SDiego Caballero     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
123015a08cf2SDiego Caballero     return success();
123115a08cf2SDiego Caballero   }
123215a08cf2SDiego Caballero };
123315a08cf2SDiego Caballero 
123412b676deSHan-Chung Wang /// Returns the number of dims can be folded away from transfer ops. It returns
123512b676deSHan-Chung Wang /// a failure if it can not determine the number of dims to be folded.
1236c65fb32dSAndrzej Warzyński ///
1237c65fb32dSAndrzej Warzyński /// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1238c65fb32dSAndrzej Warzyński /// `vectorType` is vector<16x16x1x1xf32>
1239c65fb32dSAndrzej Warzyński /// (there two inner most dims can be dropped by memref.subview ops)
1240c65fb32dSAndrzej Warzyński ///
1241c65fb32dSAndrzej Warzyński /// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1242c65fb32dSAndrzej Warzyński /// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1243c65fb32dSAndrzej Warzyński /// (only the inner most unit dim of `srcType` can be dropped)
1244c65fb32dSAndrzej Warzyński ///
1245c65fb32dSAndrzej Warzyński /// Ex 3: return "0" if `srcType` is memref<512x16x1x1xf32> and
1246c65fb32dSAndrzej Warzyński /// `vectorType` is vector<16x16x1x[1]xf32>
1247c65fb32dSAndrzej Warzyński /// (the most inner dim in `vectorType` is not a unit dim (it's a "scalable
1248c65fb32dSAndrzej Warzyński /// unit")
124912b676deSHan-Chung Wang static FailureOr<size_t>
125012b676deSHan-Chung Wang getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
125112b676deSHan-Chung Wang   SmallVector<int64_t> srcStrides;
125212b676deSHan-Chung Wang   int64_t srcOffset;
12536aaa8f25SMatthias Springer   if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
125412b676deSHan-Chung Wang     return failure();
125512b676deSHan-Chung Wang 
1256e01ff823SBenjamin Maxwell   auto isUnitDim = [](VectorType type, int dim) {
1257e01ff823SBenjamin Maxwell     return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1258e01ff823SBenjamin Maxwell   };
1259e01ff823SBenjamin Maxwell 
126012b676deSHan-Chung Wang   // According to vector.transfer_read/write semantics, the vector can be a
126112b676deSHan-Chung Wang   // slice. Thus, we have to offset the check index with `rankDiff` in
126212b676deSHan-Chung Wang   // `srcStrides` and source dim sizes.
126312b676deSHan-Chung Wang   size_t result = 0;
126412b676deSHan-Chung Wang   int rankDiff = srcType.getRank() - vectorType.getRank();
126512b676deSHan-Chung Wang   for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
126612b676deSHan-Chung Wang     // Check that the inner dim size is 1 for both memref type and vector slice.
126712b676deSHan-Chung Wang     // It can be folded only if they are 1 and the stride is 1.
126812b676deSHan-Chung Wang     int dim = vectorType.getRank() - i - 1;
126912b676deSHan-Chung Wang     if (srcStrides[dim + rankDiff] != 1 ||
1270e01ff823SBenjamin Maxwell         srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
127112b676deSHan-Chung Wang       break;
127212b676deSHan-Chung Wang     result++;
127312b676deSHan-Chung Wang   }
127412b676deSHan-Chung Wang   return result;
127512b676deSHan-Chung Wang }
127612b676deSHan-Chung Wang 
127712b676deSHan-Chung Wang /// Drop inner most contiguous unit dimensions from transfer_read operand.
127812b676deSHan-Chung Wang class DropInnerMostUnitDimsTransferRead
127912b676deSHan-Chung Wang     : public OpRewritePattern<vector::TransferReadOp> {
128027cc31b6SNicolas Vasilache   using OpRewritePattern::OpRewritePattern;
128199ef9eebSMatthias Springer 
128299ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
128399ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
128499ef9eebSMatthias Springer     // TODO: support 0-d corner case.
128599ef9eebSMatthias Springer     if (readOp.getTransferRank() == 0)
128699ef9eebSMatthias Springer       return failure();
128799ef9eebSMatthias Springer 
128899ef9eebSMatthias Springer     // TODO: support mask.
12897c38fd60SJacques Pienaar     if (readOp.getMask())
129099ef9eebSMatthias Springer       return failure();
129199ef9eebSMatthias Springer 
12925550c821STres Popp     auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1293d592c8ecSDiego Caballero     if (!srcType)
129499ef9eebSMatthias Springer       return failure();
129599ef9eebSMatthias Springer 
12967c38fd60SJacques Pienaar     if (!readOp.getPermutationMap().isMinorIdentity())
129799ef9eebSMatthias Springer       return failure();
129899ef9eebSMatthias Springer 
129999ef9eebSMatthias Springer     auto targetType = readOp.getVectorType();
130099ef9eebSMatthias Springer     if (targetType.getRank() <= 1)
130199ef9eebSMatthias Springer       return failure();
130299ef9eebSMatthias Springer 
130312b676deSHan-Chung Wang     FailureOr<size_t> maybeDimsToDrop =
130412b676deSHan-Chung Wang         getTransferFoldableInnerUnitDims(srcType, targetType);
130512b676deSHan-Chung Wang     if (failed(maybeDimsToDrop))
130699ef9eebSMatthias Springer       return failure();
130799ef9eebSMatthias Springer 
130812b676deSHan-Chung Wang     size_t dimsToDrop = maybeDimsToDrop.value();
130999ef9eebSMatthias Springer     if (dimsToDrop == 0)
131099ef9eebSMatthias Springer       return failure();
131199ef9eebSMatthias Springer 
13126479a5a4SAndrzej Warzyński     auto inBounds = readOp.getInBoundsValues();
13136479a5a4SAndrzej Warzyński     auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
13146479a5a4SAndrzej Warzyński     if (llvm::is_contained(droppedInBounds, false))
131577db8b08SAndrzej Warzyński       return failure();
131677db8b08SAndrzej Warzyński 
131799ef9eebSMatthias Springer     auto resultTargetVecType =
131899ef9eebSMatthias Springer         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1319e01ff823SBenjamin Maxwell                         targetType.getElementType(),
1320e01ff823SBenjamin Maxwell                         targetType.getScalableDims().drop_back(dimsToDrop));
132199ef9eebSMatthias Springer 
132299ef9eebSMatthias Springer     auto loc = readOp.getLoc();
1323d592c8ecSDiego Caballero     SmallVector<OpFoldResult> sizes =
1324d592c8ecSDiego Caballero         memref::getMixedSizes(rewriter, loc, readOp.getSource());
1325d592c8ecSDiego Caballero     SmallVector<OpFoldResult> offsets(srcType.getRank(),
1326d592c8ecSDiego Caballero                                       rewriter.getIndexAttr(0));
1327d592c8ecSDiego Caballero     SmallVector<OpFoldResult> strides(srcType.getRank(),
1328d592c8ecSDiego Caballero                                       rewriter.getIndexAttr(1));
13297c83d1bdSHan-Chung Wang     auto resultMemrefType =
13307c83d1bdSHan-Chung Wang         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
13317c83d1bdSHan-Chung Wang             srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
13327c83d1bdSHan-Chung Wang             strides));
13331f5807ebSAndrzej Warzyński     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
13341f5807ebSAndrzej Warzyński         readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
133599ef9eebSMatthias Springer     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1336d592c8ecSDiego Caballero         loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
133799ef9eebSMatthias Springer     auto permMap = getTransferMinorIdentityMap(
13385550c821STres Popp         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
133999ef9eebSMatthias Springer     Value result = rewriter.create<vector::TransferReadOp>(
134099ef9eebSMatthias Springer         loc, resultTargetVecType, rankedReducedView,
13417c38fd60SJacques Pienaar         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
13427c38fd60SJacques Pienaar         readOp.getPadding(),
134399ef9eebSMatthias Springer         // TODO: support mask.
134499ef9eebSMatthias Springer         /*mask=*/Value(), inBoundsAttr);
134599ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
134699ef9eebSMatthias Springer                                                      result);
134799ef9eebSMatthias Springer     return success();
134899ef9eebSMatthias Springer   }
134999ef9eebSMatthias Springer };
135099ef9eebSMatthias Springer 
135112b676deSHan-Chung Wang /// Drop inner most contiguous unit dimensions from transfer_write operand.
135212b676deSHan-Chung Wang /// E.g.,
135312b676deSHan-Chung Wang ///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
135412b676deSHan-Chung Wang ///      {in_bounds = [true, true, true, true, true]}
135512b676deSHan-Chung Wang ///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
135612b676deSHan-Chung Wang ///
135712b676deSHan-Chung Wang /// will be replaced with
135812b676deSHan-Chung Wang ///
135912b676deSHan-Chung Wang ///    %subview = memref.subview %arg0
136012b676deSHan-Chung Wang ///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
136112b676deSHan-Chung Wang ///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
136212b676deSHan-Chung Wang ///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
136312b676deSHan-Chung Wang ///      to vector<1x16x16xf32>
136412b676deSHan-Chung Wang ///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
136512b676deSHan-Chung Wang ///      {in_bounds = [true, true, true]}
136612b676deSHan-Chung Wang ///      : vector<1x16x16xf32>, memref<1x512x16xf32>
1367c65fb32dSAndrzej Warzyński ///
1368c65fb32dSAndrzej Warzyński /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
136912b676deSHan-Chung Wang class DropInnerMostUnitDimsTransferWrite
137012b676deSHan-Chung Wang     : public OpRewritePattern<vector::TransferWriteOp> {
137112b676deSHan-Chung Wang   using OpRewritePattern::OpRewritePattern;
137212b676deSHan-Chung Wang 
137312b676deSHan-Chung Wang   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
137412b676deSHan-Chung Wang                                 PatternRewriter &rewriter) const override {
137512b676deSHan-Chung Wang     // TODO: support 0-d corner case.
137612b676deSHan-Chung Wang     if (writeOp.getTransferRank() == 0)
137712b676deSHan-Chung Wang       return failure();
137812b676deSHan-Chung Wang 
137912b676deSHan-Chung Wang     // TODO: support mask.
138012b676deSHan-Chung Wang     if (writeOp.getMask())
138112b676deSHan-Chung Wang       return failure();
138212b676deSHan-Chung Wang 
138312b676deSHan-Chung Wang     auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1384d193ac4fSHan-Chung Wang     if (!srcType)
138512b676deSHan-Chung Wang       return failure();
138612b676deSHan-Chung Wang 
138712b676deSHan-Chung Wang     if (!writeOp.getPermutationMap().isMinorIdentity())
138812b676deSHan-Chung Wang       return failure();
138912b676deSHan-Chung Wang 
139012b676deSHan-Chung Wang     auto targetType = writeOp.getVectorType();
139112b676deSHan-Chung Wang     if (targetType.getRank() <= 1)
139212b676deSHan-Chung Wang       return failure();
139312b676deSHan-Chung Wang 
139412b676deSHan-Chung Wang     FailureOr<size_t> maybeDimsToDrop =
139512b676deSHan-Chung Wang         getTransferFoldableInnerUnitDims(srcType, targetType);
139612b676deSHan-Chung Wang     if (failed(maybeDimsToDrop))
139712b676deSHan-Chung Wang       return failure();
139812b676deSHan-Chung Wang 
139912b676deSHan-Chung Wang     size_t dimsToDrop = maybeDimsToDrop.value();
140012b676deSHan-Chung Wang     if (dimsToDrop == 0)
140112b676deSHan-Chung Wang       return failure();
140212b676deSHan-Chung Wang 
14036479a5a4SAndrzej Warzyński     auto inBounds = writeOp.getInBoundsValues();
14046479a5a4SAndrzej Warzyński     auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
14056479a5a4SAndrzej Warzyński     if (llvm::is_contained(droppedInBounds, false))
14066479a5a4SAndrzej Warzyński       return failure();
14076479a5a4SAndrzej Warzyński 
140812b676deSHan-Chung Wang     auto resultTargetVecType =
140912b676deSHan-Chung Wang         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1410e01ff823SBenjamin Maxwell                         targetType.getElementType(),
1411e01ff823SBenjamin Maxwell                         targetType.getScalableDims().drop_back(dimsToDrop));
141212b676deSHan-Chung Wang 
1413d193ac4fSHan-Chung Wang     Location loc = writeOp.getLoc();
1414d193ac4fSHan-Chung Wang     SmallVector<OpFoldResult> sizes =
1415d193ac4fSHan-Chung Wang         memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1416d193ac4fSHan-Chung Wang     SmallVector<OpFoldResult> offsets(srcType.getRank(),
1417d193ac4fSHan-Chung Wang                                       rewriter.getIndexAttr(0));
1418d193ac4fSHan-Chung Wang     SmallVector<OpFoldResult> strides(srcType.getRank(),
1419d193ac4fSHan-Chung Wang                                       rewriter.getIndexAttr(1));
14207c83d1bdSHan-Chung Wang     auto resultMemrefType =
14217c83d1bdSHan-Chung Wang         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
14227c83d1bdSHan-Chung Wang             srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
14237c83d1bdSHan-Chung Wang             strides));
14241f5807ebSAndrzej Warzyński     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
14251f5807ebSAndrzej Warzyński         writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
142612b676deSHan-Chung Wang 
142712b676deSHan-Chung Wang     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1428d193ac4fSHan-Chung Wang         loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
142912b676deSHan-Chung Wang     auto permMap = getTransferMinorIdentityMap(
143012b676deSHan-Chung Wang         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
143112b676deSHan-Chung Wang 
143212b676deSHan-Chung Wang     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
143312b676deSHan-Chung Wang         loc, resultTargetVecType, writeOp.getVector());
143412b676deSHan-Chung Wang     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
143512b676deSHan-Chung Wang         writeOp, shapeCast, rankedReducedView,
143612b676deSHan-Chung Wang         writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
143712b676deSHan-Chung Wang         // TODO: support mask.
143812b676deSHan-Chung Wang         /*mask=*/Value(), inBoundsAttr);
143912b676deSHan-Chung Wang     return success();
144012b676deSHan-Chung Wang   }
144112b676deSHan-Chung Wang };
144212b676deSHan-Chung Wang 
1443fb7ef637SJakub Kuderski /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1444fb7ef637SJakub Kuderski /// semantics to a contraction suitable for MMT (matrix matrix multiplication
1445fb7ef637SJakub Kuderski /// with the RHS transposed) lowering.
1446fb7ef637SJakub Kuderski struct CanonicalizeContractMatmulToMMT final
1447fb7ef637SJakub Kuderski     : OpRewritePattern<vector::ContractionOp> {
1448fb7ef637SJakub Kuderski   using OpRewritePattern::OpRewritePattern;
1449fb7ef637SJakub Kuderski 
1450fb7ef637SJakub Kuderski   using FilterConstraintType =
1451fb7ef637SJakub Kuderski       std::function<LogicalResult(vector::ContractionOp op)>;
1452fb7ef637SJakub Kuderski 
1453fb7ef637SJakub Kuderski   CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1454fb7ef637SJakub Kuderski                                   FilterConstraintType constraint)
1455fb7ef637SJakub Kuderski       : OpRewritePattern<vector::ContractionOp>(context, benefit),
1456fb7ef637SJakub Kuderski         filter(std::move(constraint)) {}
1457fb7ef637SJakub Kuderski 
1458fb7ef637SJakub Kuderski   LogicalResult matchAndRewrite(vector::ContractionOp op,
1459fb7ef637SJakub Kuderski                                 PatternRewriter &rewriter) const override {
1460fb7ef637SJakub Kuderski     if (failed(filter(op)))
1461fb7ef637SJakub Kuderski       return failure();
1462fb7ef637SJakub Kuderski 
1463fb7ef637SJakub Kuderski     Location loc = op.getLoc();
1464fb7ef637SJakub Kuderski     Value lhs = op.getLhs();
1465fb7ef637SJakub Kuderski     Value rhs = op.getRhs();
1466fb7ef637SJakub Kuderski     Value res = op.getAcc();
1467fb7ef637SJakub Kuderski 
1468fb7ef637SJakub Kuderski     // Set up the parallel/reduction structure in right form.
1469fb7ef637SJakub Kuderski     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1470fe8a62c4SUday Bondhugula     auto infer = [&](MapList m) {
1471fe8a62c4SUday Bondhugula       return AffineMap::inferFromExprList(m, op.getContext());
1472fe8a62c4SUday Bondhugula     };
1473fb7ef637SJakub Kuderski     AffineExpr m;
1474fb7ef637SJakub Kuderski     AffineExpr n;
1475fb7ef637SJakub Kuderski     AffineExpr k;
1476fb7ef637SJakub Kuderski     bindDims(rewriter.getContext(), m, n, k);
1477fb7ef637SJakub Kuderski     static constexpr std::array<int64_t, 2> perm = {1, 0};
1478fb7ef637SJakub Kuderski     auto iteratorTypes = op.getIteratorTypes().getValue();
1479fb7ef637SJakub Kuderski     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1480fb7ef637SJakub Kuderski     if (iteratorTypes.size() != 3 ||
1481fb7ef637SJakub Kuderski         !vector::isParallelIterator(iteratorTypes[0]) ||
1482fb7ef637SJakub Kuderski         !vector::isParallelIterator(iteratorTypes[1]) ||
1483fb7ef637SJakub Kuderski         !vector::isReductionIterator(iteratorTypes[2]))
1484fb7ef637SJakub Kuderski       return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1485fb7ef637SJakub Kuderski 
1486fb7ef637SJakub Kuderski     // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1487fb7ef637SJakub Kuderski     const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1488fb7ef637SJakub Kuderski     if (maps == canonicalForm)
1489fb7ef637SJakub Kuderski       return rewriter.notifyMatchFailure(op, "already in the canonical form");
1490fb7ef637SJakub Kuderski 
1491fb7ef637SJakub Kuderski     // Create a vector transpose making sure to emit zero/sign-extend at the
1492fb7ef637SJakub Kuderski     // end.
1493fb7ef637SJakub Kuderski     auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1494fb7ef637SJakub Kuderski       if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1495fb7ef637SJakub Kuderski         Value trans =
1496fb7ef637SJakub Kuderski             rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
14975041fe84SLei Zhang         VectorType newType =
14984b2ba5a6SAndrzej Warzyński             cast<VectorType>(trans.getType())
14994b2ba5a6SAndrzej Warzyński                 .clone(cast<VectorType>(mat.getType()).getElementType());
15005041fe84SLei Zhang         return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1501fb7ef637SJakub Kuderski       }
1502fb7ef637SJakub Kuderski       if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1503fb7ef637SJakub Kuderski         Value trans =
1504fb7ef637SJakub Kuderski             rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
15055041fe84SLei Zhang         VectorType newType =
15065041fe84SLei Zhang             VectorType::get(cast<VectorType>(trans.getType()).getShape(),
15075041fe84SLei Zhang                             cast<VectorType>(mat.getType()).getElementType());
15085041fe84SLei Zhang         return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1509fb7ef637SJakub Kuderski       }
1510fb7ef637SJakub Kuderski       return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1511fb7ef637SJakub Kuderski     };
1512fb7ef637SJakub Kuderski 
1513fb7ef637SJakub Kuderski     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1514fb7ef637SJakub Kuderski       rhs = createTranspose(rhs);
1515fb7ef637SJakub Kuderski     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1516fb7ef637SJakub Kuderski       lhs = createTranspose(lhs);
1517fb7ef637SJakub Kuderski     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1518fb7ef637SJakub Kuderski       rhs = createTranspose(rhs);
1519fb7ef637SJakub Kuderski       lhs = createTranspose(lhs);
1520fb7ef637SJakub Kuderski     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1521fb7ef637SJakub Kuderski       std::swap(rhs, lhs);
1522fb7ef637SJakub Kuderski       rhs = createTranspose(rhs);
1523fb7ef637SJakub Kuderski       lhs = createTranspose(lhs);
1524fb7ef637SJakub Kuderski     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1525fb7ef637SJakub Kuderski       std::swap(rhs, lhs);
1526fb7ef637SJakub Kuderski       rhs = createTranspose(rhs);
1527fb7ef637SJakub Kuderski     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1528fb7ef637SJakub Kuderski       std::swap(lhs, rhs);
1529fb7ef637SJakub Kuderski       lhs = createTranspose(lhs);
1530fb7ef637SJakub Kuderski     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1531fb7ef637SJakub Kuderski       std::swap(lhs, rhs);
1532fb7ef637SJakub Kuderski     } else {
1533fb7ef637SJakub Kuderski       return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1534fb7ef637SJakub Kuderski     }
1535fb7ef637SJakub Kuderski     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1536fb7ef637SJakub Kuderski         op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1537fb7ef637SJakub Kuderski         op.getIteratorTypes());
1538fb7ef637SJakub Kuderski     return success();
1539fb7ef637SJakub Kuderski   };
1540fb7ef637SJakub Kuderski 
1541fb7ef637SJakub Kuderski private:
1542fb7ef637SJakub Kuderski   FilterConstraintType filter;
1543fb7ef637SJakub Kuderski };
1544fb7ef637SJakub Kuderski 
15459a795f0cSManish Gupta /// Pattern to fold arithmetic extensions on floating point data types into
15469a795f0cSManish Gupta /// vector contraction operations. linalg.matmul introduces arithmetic
15479a795f0cSManish Gupta /// extensions on its operands. Please mlir snippets below for more details.
15489a795f0cSManish Gupta /// ```mlir
15499a795f0cSManish Gupta ///   "linalg.matmul"(%lhs, %rhs, %acc) ({
15509a795f0cSManish Gupta ///      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
15519a795f0cSManish Gupta ///        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
15529a795f0cSManish Gupta ///        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
15539a795f0cSManish Gupta ///        %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
15549a795f0cSManish Gupta ///        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
15559a795f0cSManish Gupta ///        "linalg.yield"(%acc) : (f32) -> ()
15569a795f0cSManish Gupta ///     })
15579a795f0cSManish Gupta /// ```
15589a795f0cSManish Gupta /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
15599a795f0cSManish Gupta /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
15609a795f0cSManish Gupta /// This pattern folds the arithmetic extensions into the vector contraction and
15619a795f0cSManish Gupta /// enables the usage of native mixed precision Tensor Core instructions.
1562ac1e22f3SStanley Winata template <typename ExtOp>
15639a795f0cSManish Gupta struct FoldArithExtIntoContractionOp
15649a795f0cSManish Gupta     : public OpRewritePattern<vector::ContractionOp> {
15659a795f0cSManish Gupta   using OpRewritePattern::OpRewritePattern;
15669a795f0cSManish Gupta 
15679a795f0cSManish Gupta   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
15689a795f0cSManish Gupta                                 PatternRewriter &rewriter) const override {
15699a795f0cSManish Gupta 
1570ac1e22f3SStanley Winata     auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1571ac1e22f3SStanley Winata     auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
15729a795f0cSManish Gupta 
15739a795f0cSManish Gupta     if (!lhsDefOp || !rhsDefOp) {
15749a795f0cSManish Gupta       return rewriter.notifyMatchFailure(contractOp,
15759a795f0cSManish Gupta                                          "no defining op on contract operands");
15769a795f0cSManish Gupta     }
15779a795f0cSManish Gupta 
15789a795f0cSManish Gupta     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
15799a795f0cSManish Gupta         contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
15809a795f0cSManish Gupta         contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
15819a795f0cSManish Gupta         contractOp.getIteratorTypesAttr());
15829a795f0cSManish Gupta 
15839a795f0cSManish Gupta     return success();
15849a795f0cSManish Gupta   }
15859a795f0cSManish Gupta };
15869a795f0cSManish Gupta 
1587d33bad66SJakub Kuderski /// Pattern to fold chained reduction to a series of vector additions and a
1588d33bad66SJakub Kuderski /// final reduction. This form should require fewer subgroup operations.
1589d33bad66SJakub Kuderski ///
1590d33bad66SJakub Kuderski /// ```mlir
1591d33bad66SJakub Kuderski /// %a = vector.reduction <add> %x, %acc
1592d33bad66SJakub Kuderski /// %b = vector.reduction <add> %y, %a
1593d33bad66SJakub Kuderski ///  ==>
1594d33bad66SJakub Kuderski /// %a = arith.addf %x, %y
1595d33bad66SJakub Kuderski /// %b = vector.reduction <add> %a, %acc
1596d33bad66SJakub Kuderski /// ```
1597d33bad66SJakub Kuderski struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1598d33bad66SJakub Kuderski   using OpRewritePattern::OpRewritePattern;
1599d33bad66SJakub Kuderski 
1600d33bad66SJakub Kuderski   LogicalResult matchAndRewrite(vector::ReductionOp op,
1601d33bad66SJakub Kuderski                                 PatternRewriter &rewriter) const override {
1602d33bad66SJakub Kuderski     // TODO: Handle other combining kinds.
1603d33bad66SJakub Kuderski     if (op.getKind() != vector::CombiningKind::ADD)
1604d33bad66SJakub Kuderski       return failure();
1605d33bad66SJakub Kuderski 
1606d33bad66SJakub Kuderski     // Accumulator is optional.
1607d33bad66SJakub Kuderski     Value acc = op.getAcc();
1608d33bad66SJakub Kuderski     if (!acc)
1609d33bad66SJakub Kuderski       return failure();
1610d33bad66SJakub Kuderski 
1611d33bad66SJakub Kuderski     if (!acc.getType().isIntOrFloat())
1612d33bad66SJakub Kuderski       return failure();
1613d33bad66SJakub Kuderski 
1614d33bad66SJakub Kuderski     auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1615d33bad66SJakub Kuderski     if (!parentReduction)
1616d33bad66SJakub Kuderski       return failure();
1617d33bad66SJakub Kuderski 
1618d33bad66SJakub Kuderski     Location loc = op.getLoc();
1619d33bad66SJakub Kuderski     Value vAdd;
1620d33bad66SJakub Kuderski     if (isa<IntegerType>(acc.getType())) {
1621d33bad66SJakub Kuderski       vAdd = rewriter.createOrFold<arith::AddIOp>(
1622d33bad66SJakub Kuderski           loc, parentReduction.getVector(), op.getVector());
1623d33bad66SJakub Kuderski     } else {
1624d33bad66SJakub Kuderski       vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1625d33bad66SJakub Kuderski                                             op.getVector());
1626d33bad66SJakub Kuderski     }
1627d33bad66SJakub Kuderski     rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1628d33bad66SJakub Kuderski                                                      parentReduction.getAcc());
1629d33bad66SJakub Kuderski     return success();
1630d33bad66SJakub Kuderski   }
1631d33bad66SJakub Kuderski };
1632d33bad66SJakub Kuderski 
1633de61875eSHugo Trachino // Helper function dropping unit non-scalable dimension from a VectorType
1634de61875eSHugo Trachino // keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
1635de61875eSHugo Trachino // dimensions are not dropped. Folding such dimensions would require "shifting"
1636de61875eSHugo Trachino // the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1637de61875eSHugo Trachino // vector<[4]xf32>). This could be implemented in the future.
1638de61875eSHugo Trachino static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1639de61875eSHugo Trachino   auto inVecShape = inVecTy.getShape();
1640de61875eSHugo Trachino   SmallVector<int64_t> newShape;
1641de61875eSHugo Trachino   SmallVector<bool> newScalableDims;
1642de61875eSHugo Trachino   for (auto [dim, isScalable] :
1643de61875eSHugo Trachino        llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1644de61875eSHugo Trachino     if (dim == 1 && !isScalable)
1645de61875eSHugo Trachino       continue;
1646de61875eSHugo Trachino 
1647de61875eSHugo Trachino     newShape.push_back(dim);
1648de61875eSHugo Trachino     newScalableDims.push_back(isScalable);
1649de61875eSHugo Trachino   }
1650de61875eSHugo Trachino   // All dims have been dropped, return vector<1xeType>.
1651de61875eSHugo Trachino   if (newShape.empty()) {
1652de61875eSHugo Trachino     newShape.push_back(1);
1653de61875eSHugo Trachino     newScalableDims.push_back(false);
1654de61875eSHugo Trachino   }
1655de61875eSHugo Trachino 
1656de61875eSHugo Trachino   return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1657de61875eSHugo Trachino }
1658de61875eSHugo Trachino 
1659de61875eSHugo Trachino /// For vectors with at least one unit dim, replaces:
1660c02d07fdSAndrzej Warzyński ///   elementwise(a, b)
1661c02d07fdSAndrzej Warzyński /// with:
1662c02d07fdSAndrzej Warzyński ///   sc_a = shape_cast(a)
1663c02d07fdSAndrzej Warzyński ///   sc_b = shape_cast(b)
1664c02d07fdSAndrzej Warzyński ///   res = elementwise(sc_a, sc_b)
1665c02d07fdSAndrzej Warzyński ///   return shape_cast(res)
1666c02d07fdSAndrzej Warzyński /// The newly inserted shape_cast Ops fold (before elementwise Op) and then
1667c02d07fdSAndrzej Warzyński /// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
1668c02d07fdSAndrzej Warzyński /// required to be rank > 1.
1669c02d07fdSAndrzej Warzyński ///
1670c02d07fdSAndrzej Warzyński /// Ex:
1671c02d07fdSAndrzej Warzyński ///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1672c02d07fdSAndrzej Warzyński ///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1673c02d07fdSAndrzej Warzyński ///
1674c02d07fdSAndrzej Warzyński /// gets converted to:
1675c02d07fdSAndrzej Warzyński ///
1676c02d07fdSAndrzej Warzyński ///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1677c02d07fdSAndrzej Warzyński ///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1678c02d07fdSAndrzej Warzyński ///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1679c02d07fdSAndrzej Warzyński ///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1680c02d07fdSAndrzej Warzyński ///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1681c02d07fdSAndrzej Warzyński ///
1682c02d07fdSAndrzej Warzyński /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1683c02d07fdSAndrzej Warzyński /// `%cast`.
1684c02d07fdSAndrzej Warzyński struct DropUnitDimFromElementwiseOps final
1685c02d07fdSAndrzej Warzyński     : public OpTraitRewritePattern<OpTrait::Elementwise> {
1686c02d07fdSAndrzej Warzyński   using OpTraitRewritePattern::OpTraitRewritePattern;
1687c02d07fdSAndrzej Warzyński   LogicalResult matchAndRewrite(Operation *op,
1688c02d07fdSAndrzej Warzyński                                 PatternRewriter &rewriter) const override {
16892c9ba9c3SJerry Wu     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1690c02d07fdSAndrzej Warzyński       return failure();
1691c02d07fdSAndrzej Warzyński 
16922c9ba9c3SJerry Wu     auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
16932c9ba9c3SJerry Wu     if (!resultVectorType)
16942c9ba9c3SJerry Wu       return failure();
16952c9ba9c3SJerry Wu 
169613b37626SDiego Caballero     // Check the operand pre-conditions. For `Elementwise` ops all operands are
169713b37626SDiego Caballero     // guaranteed to have identical shapes (with some exceptions such as
169813b37626SDiego Caballero     // `arith.select`) and it suffices to only check one of them.
169913b37626SDiego Caballero     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1700eaabd762SHan-Chung Wang     if (!sourceVectorType)
170113b37626SDiego Caballero       return failure();
1702eaabd762SHan-Chung Wang     if (sourceVectorType.getRank() < 2)
1703eaabd762SHan-Chung Wang       return failure();
1704eaabd762SHan-Chung Wang 
1705c02d07fdSAndrzej Warzyński     SmallVector<Value> newOperands;
1706c02d07fdSAndrzej Warzyński     auto loc = op->getLoc();
1707c02d07fdSAndrzej Warzyński     for (auto operand : op->getOperands()) {
17082c9ba9c3SJerry Wu       auto opVectorType = cast<VectorType>(operand.getType());
1709de61875eSHugo Trachino       auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1710de61875eSHugo Trachino       if (newVType == opVectorType)
1711de61875eSHugo Trachino         return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1712de61875eSHugo Trachino 
1713c02d07fdSAndrzej Warzyński       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
1714c02d07fdSAndrzej Warzyński       newOperands.push_back(opSC);
1715c02d07fdSAndrzej Warzyński     }
1716c02d07fdSAndrzej Warzyński 
17172c9ba9c3SJerry Wu     VectorType newResultVectorType =
1718de61875eSHugo Trachino         dropNonScalableUnitDimFromType(resultVectorType);
1719de61875eSHugo Trachino     // Create an updated elementwise Op without unit dim.
1720c02d07fdSAndrzej Warzyński     Operation *elementwiseOp =
1721c02d07fdSAndrzej Warzyński         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
17222c9ba9c3SJerry Wu                         newResultVectorType, op->getAttrs());
1723c02d07fdSAndrzej Warzyński 
1724de61875eSHugo Trachino     // Restore the unit dim by applying vector.shape_cast to the result.
17252c9ba9c3SJerry Wu     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
1726c02d07fdSAndrzej Warzyński                                              elementwiseOp->getResult(0));
1727c02d07fdSAndrzej Warzyński 
1728c02d07fdSAndrzej Warzyński     return success();
1729c02d07fdSAndrzej Warzyński   }
1730c02d07fdSAndrzej Warzyński };
1731c02d07fdSAndrzej Warzyński 
1732da8778e4SBenjamin Maxwell /// A pattern to drop unit dims from vector.transpose.
1733da8778e4SBenjamin Maxwell ///
1734da8778e4SBenjamin Maxwell /// Example:
1735da8778e4SBenjamin Maxwell ///
1736da8778e4SBenjamin Maxwell ///  BEFORE:
1737da8778e4SBenjamin Maxwell ///  ```mlir
1738da8778e4SBenjamin Maxwell ///  %transpose = vector.transpose %vector, [3, 0, 1, 2]
1739da8778e4SBenjamin Maxwell ///    : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
1740da8778e4SBenjamin Maxwell ///  ```
1741da8778e4SBenjamin Maxwell ///
1742da8778e4SBenjamin Maxwell ///  AFTER:
1743da8778e4SBenjamin Maxwell ///  ```mlir
1744da8778e4SBenjamin Maxwell ///  %dropDims = vector.shape_cast %vector
1745da8778e4SBenjamin Maxwell ///    : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
1746da8778e4SBenjamin Maxwell ///  %transpose = vector.transpose %0, [1, 0]
1747da8778e4SBenjamin Maxwell ///    : vector<4x[4]xf32> to vector<[4]x4xf32>
1748da8778e4SBenjamin Maxwell ///  %restoreDims = vector.shape_cast %transpose
1749da8778e4SBenjamin Maxwell ///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1750da8778e4SBenjamin Maxwell ///  ```
1751da8778e4SBenjamin Maxwell struct DropUnitDimsFromTransposeOp final
1752da8778e4SBenjamin Maxwell     : OpRewritePattern<vector::TransposeOp> {
1753da8778e4SBenjamin Maxwell   using OpRewritePattern::OpRewritePattern;
1754da8778e4SBenjamin Maxwell 
1755da8778e4SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::TransposeOp op,
1756da8778e4SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
1757da8778e4SBenjamin Maxwell     VectorType sourceType = op.getSourceVectorType();
1758da8778e4SBenjamin Maxwell     VectorType sourceTypeWithoutUnitDims =
1759da8778e4SBenjamin Maxwell         dropNonScalableUnitDimFromType(sourceType);
1760da8778e4SBenjamin Maxwell 
1761da8778e4SBenjamin Maxwell     if (sourceType == sourceTypeWithoutUnitDims)
1762da8778e4SBenjamin Maxwell       return failure();
1763da8778e4SBenjamin Maxwell 
1764da8778e4SBenjamin Maxwell     // Construct a map from dimIdx -> number of dims dropped before dimIdx.
1765da8778e4SBenjamin Maxwell     auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
1766da8778e4SBenjamin Maxwell     SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
1767da8778e4SBenjamin Maxwell     int64_t droppedDims = 0;
1768da8778e4SBenjamin Maxwell     for (auto [i, dim] : llvm::enumerate(sourceDims)) {
1769da8778e4SBenjamin Maxwell       droppedDimsBefore[i] = droppedDims;
1770da8778e4SBenjamin Maxwell       if (dim == std::make_tuple(1, false))
1771da8778e4SBenjamin Maxwell         ++droppedDims;
1772da8778e4SBenjamin Maxwell     }
1773da8778e4SBenjamin Maxwell 
1774da8778e4SBenjamin Maxwell     // Drop unit dims from transpose permutation.
1775da8778e4SBenjamin Maxwell     ArrayRef<int64_t> perm = op.getPermutation();
1776da8778e4SBenjamin Maxwell     SmallVector<int64_t> newPerm;
1777da8778e4SBenjamin Maxwell     for (int64_t idx : perm) {
1778da8778e4SBenjamin Maxwell       if (sourceDims[idx] == std::make_tuple(1, false))
1779da8778e4SBenjamin Maxwell         continue;
1780da8778e4SBenjamin Maxwell       newPerm.push_back(idx - droppedDimsBefore[idx]);
1781da8778e4SBenjamin Maxwell     }
1782da8778e4SBenjamin Maxwell 
1783201da87cSHan-Chung Wang     // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
1784201da87cSHan-Chung Wang     // type when the dimensions are unit dimensions. In this case, the newPerm
1785201da87cSHan-Chung Wang     // should be [0].
1786201da87cSHan-Chung Wang     if (newPerm.empty()) {
1787201da87cSHan-Chung Wang       newPerm.push_back(0);
1788201da87cSHan-Chung Wang     }
1789201da87cSHan-Chung Wang 
1790da8778e4SBenjamin Maxwell     Location loc = op.getLoc();
1791da8778e4SBenjamin Maxwell     // Drop the unit dims via shape_cast.
1792da8778e4SBenjamin Maxwell     auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
1793da8778e4SBenjamin Maxwell         loc, sourceTypeWithoutUnitDims, op.getVector());
1794da8778e4SBenjamin Maxwell     // Create the new transpose.
1795*aa295216SJay Foad     auto transposeWithoutUnitDims =
1796da8778e4SBenjamin Maxwell         rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
1797da8778e4SBenjamin Maxwell     // Restore the unit dims via shape cast.
1798da8778e4SBenjamin Maxwell     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
1799*aa295216SJay Foad         op, op.getResultVectorType(), transposeWithoutUnitDims);
1800da8778e4SBenjamin Maxwell 
1801da492d43SBenjamin Maxwell     return success();
1802da8778e4SBenjamin Maxwell   }
1803da8778e4SBenjamin Maxwell };
1804da8778e4SBenjamin Maxwell 
1805a3b34e67SQuinn Dawkins /// A pattern to drop unit dims from the iter_args of an scf.for.
1806a3b34e67SQuinn Dawkins ///
1807a3b34e67SQuinn Dawkins /// Example:
1808a3b34e67SQuinn Dawkins ///
1809a3b34e67SQuinn Dawkins ///  BEFORE:
1810a3b34e67SQuinn Dawkins ///  ```mlir
1811a3b34e67SQuinn Dawkins ///  %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
1812a3b34e67SQuinn Dawkins ///    ...
1813a3b34e67SQuinn Dawkins ///    scf.yield %
1814a3b34e67SQuinn Dawkins ///  }
1815a3b34e67SQuinn Dawkins ///  ```
1816a3b34e67SQuinn Dawkins ///
1817a3b34e67SQuinn Dawkins ///  AFTER:
1818a3b34e67SQuinn Dawkins ///  ```mlir
1819a3b34e67SQuinn Dawkins ///  %drop = vector.shape_cast %init
1820a3b34e67SQuinn Dawkins ///    : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
1821a3b34e67SQuinn Dawkins ///  %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
1822a3b34e67SQuinn Dawkins ///    %new_iter = vector.shape_cast %iter
1823a3b34e67SQuinn Dawkins ///      : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1824a3b34e67SQuinn Dawkins ///    ...
1825a3b34e67SQuinn Dawkins ///  }
1826a3b34e67SQuinn Dawkins ///  %res = vector.shape_cast %new_loop
1827a3b34e67SQuinn Dawkins ///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1828a3b34e67SQuinn Dawkins ///  ```
1829a3b34e67SQuinn Dawkins struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
1830a3b34e67SQuinn Dawkins   using OpRewritePattern::OpRewritePattern;
1831a3b34e67SQuinn Dawkins 
1832a3b34e67SQuinn Dawkins   LogicalResult matchAndRewrite(scf::ForOp forOp,
1833a3b34e67SQuinn Dawkins                                 PatternRewriter &rewriter) const override {
1834a3b34e67SQuinn Dawkins     /// Find the first iter_arg with droppable unit dims. Further applications
1835a3b34e67SQuinn Dawkins     /// of this pattern will apply to later arguments.
1836a3b34e67SQuinn Dawkins     for (OpOperand &operand : forOp.getInitArgsMutable()) {
1837a3b34e67SQuinn Dawkins       auto vectorType = dyn_cast<VectorType>(operand.get().getType());
1838a3b34e67SQuinn Dawkins       if (!vectorType)
1839a3b34e67SQuinn Dawkins         continue;
1840a3b34e67SQuinn Dawkins 
1841a3b34e67SQuinn Dawkins       VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
1842a3b34e67SQuinn Dawkins       if (vectorType == newVectorType)
1843a3b34e67SQuinn Dawkins         continue;
1844a3b34e67SQuinn Dawkins 
1845a3b34e67SQuinn Dawkins       // Create a new ForOp with that iter operand replaced.
1846a3b34e67SQuinn Dawkins       auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
1847a3b34e67SQuinn Dawkins         return b.create<vector::ShapeCastOp>(loc, type, source);
1848a3b34e67SQuinn Dawkins       };
1849a3b34e67SQuinn Dawkins 
1850a3b34e67SQuinn Dawkins       Value replacement =
1851a3b34e67SQuinn Dawkins           castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
1852a3b34e67SQuinn Dawkins       rewriter.replaceOp(forOp,
1853a3b34e67SQuinn Dawkins                          replaceAndCastForOpIterArg(rewriter, forOp, operand,
1854a3b34e67SQuinn Dawkins                                                     replacement, castFn));
1855a3b34e67SQuinn Dawkins       return success();
1856a3b34e67SQuinn Dawkins     }
1857a3b34e67SQuinn Dawkins     return failure();
1858a3b34e67SQuinn Dawkins   }
1859a3b34e67SQuinn Dawkins };
1860a3b34e67SQuinn Dawkins 
1861d33bad66SJakub Kuderski /// Pattern to eliminate redundant zero-constants added to reduction operands.
1862d33bad66SJakub Kuderski /// It's enough for there to be one initial zero value, so we can eliminate the
1863d33bad66SJakub Kuderski /// extra ones that feed into `vector.reduction <add>`. These get created by the
1864d33bad66SJakub Kuderski /// `ChainedReduction` pattern.
1865d33bad66SJakub Kuderski ///
1866d33bad66SJakub Kuderski /// ```mlir
1867d33bad66SJakub Kuderski /// %a = arith.addf %x, %zero
1868d33bad66SJakub Kuderski /// %b = arith.addf %a, %y
1869d33bad66SJakub Kuderski /// %c = vector.reduction <add> %b, %acc
1870d33bad66SJakub Kuderski ///  ==>
1871d33bad66SJakub Kuderski /// %b = arith.addf %a, %y
1872d33bad66SJakub Kuderski /// %c = vector.reduction <add> %b, %acc
1873d33bad66SJakub Kuderski /// ```
1874d33bad66SJakub Kuderski struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
1875d33bad66SJakub Kuderski   using OpRewritePattern::OpRewritePattern;
1876d33bad66SJakub Kuderski 
1877d33bad66SJakub Kuderski   LogicalResult matchAndRewrite(vector::ReductionOp op,
1878d33bad66SJakub Kuderski                                 PatternRewriter &rewriter) const override {
1879d33bad66SJakub Kuderski     // TODO: Handle other reduction kinds and their identity values.
1880d33bad66SJakub Kuderski     if (op.getKind() != vector::CombiningKind::ADD)
1881d33bad66SJakub Kuderski       return failure();
1882d33bad66SJakub Kuderski 
1883d33bad66SJakub Kuderski     Type elemType = op.getSourceVectorType().getElementType();
1884d33bad66SJakub Kuderski     // The integer case should be handled by `arith.addi` folders, only check
1885d33bad66SJakub Kuderski     // for floats here.
1886d33bad66SJakub Kuderski     if (!isa<FloatType>(elemType))
1887d33bad66SJakub Kuderski       return failure();
1888d33bad66SJakub Kuderski 
1889d33bad66SJakub Kuderski     auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
1890d33bad66SJakub Kuderski     if (!vAdd)
1891d33bad66SJakub Kuderski       return failure();
1892d33bad66SJakub Kuderski     auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
1893d33bad66SJakub Kuderski     if (!addLhs)
1894d33bad66SJakub Kuderski       return failure();
1895d33bad66SJakub Kuderski 
1896d33bad66SJakub Kuderski     if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
1897d33bad66SJakub Kuderski       return failure();
1898d33bad66SJakub Kuderski 
1899d33bad66SJakub Kuderski     auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
1900d33bad66SJakub Kuderski                                                  vAdd.getRhs());
1901d33bad66SJakub Kuderski     rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
1902d33bad66SJakub Kuderski                                                      op.getAcc());
1903d33bad66SJakub Kuderski     return success();
1904d33bad66SJakub Kuderski   }
1905d33bad66SJakub Kuderski };
1906d33bad66SJakub Kuderski 
190707677113SJakub Kuderski /// Example:
190807677113SJakub Kuderski /// ```
190907677113SJakub Kuderski /// %a = vector.reduction <add> %x : vector<2xf32> into f32
191007677113SJakub Kuderski /// ```
191107677113SJakub Kuderski /// is transformed into:
191207677113SJakub Kuderski /// ```
191307677113SJakub Kuderski /// %y = vector.extract %x[0] : f32 from vector<2xf32>
191407677113SJakub Kuderski /// %z = vector.extract %x[1] : f32 from vector<2xf32>
191507677113SJakub Kuderski /// %a = arith.addf %y, %z : f32
191607677113SJakub Kuderski /// ```
191707677113SJakub Kuderski struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
191807677113SJakub Kuderski   BreakDownVectorReduction(MLIRContext *context,
191907677113SJakub Kuderski                            unsigned maxNumElementsToExtract,
192007677113SJakub Kuderski                            PatternBenefit benefit)
192107677113SJakub Kuderski       : OpRewritePattern(context, benefit),
192207677113SJakub Kuderski         maxNumElementsToExtract(maxNumElementsToExtract) {}
192307677113SJakub Kuderski 
192407677113SJakub Kuderski   LogicalResult matchAndRewrite(vector::ReductionOp op,
192507677113SJakub Kuderski                                 PatternRewriter &rewriter) const override {
192607677113SJakub Kuderski     VectorType type = op.getSourceVectorType();
192707677113SJakub Kuderski     if (type.isScalable() || op.isMasked())
192807677113SJakub Kuderski       return failure();
192907677113SJakub Kuderski     assert(type.getRank() == 1 && "Expected a 1-d vector");
193007677113SJakub Kuderski 
193107677113SJakub Kuderski     int64_t numElems = type.getNumElements();
193207677113SJakub Kuderski     if (numElems > maxNumElementsToExtract) {
193307677113SJakub Kuderski       return rewriter.notifyMatchFailure(
193407677113SJakub Kuderski           op, llvm::formatv("has too many vector elements ({0}) to break down "
193507677113SJakub Kuderski                             "(max allowed: {1})",
193607677113SJakub Kuderski                             numElems, maxNumElementsToExtract));
193707677113SJakub Kuderski     }
193807677113SJakub Kuderski 
193907677113SJakub Kuderski     Location loc = op.getLoc();
194007677113SJakub Kuderski     SmallVector<Value> extracted(numElems, nullptr);
194107677113SJakub Kuderski     for (auto [idx, extractedElem] : llvm::enumerate(extracted))
194207677113SJakub Kuderski       extractedElem = rewriter.create<vector::ExtractOp>(
194307677113SJakub Kuderski           loc, op.getVector(), static_cast<int64_t>(idx));
194407677113SJakub Kuderski 
194507677113SJakub Kuderski     Value res = extracted.front();
194607677113SJakub Kuderski     for (auto extractedElem : llvm::drop_begin(extracted))
194707677113SJakub Kuderski       res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
194807677113SJakub Kuderski                                        extractedElem, op.getFastmathAttr());
194907677113SJakub Kuderski     if (Value acc = op.getAcc())
195007677113SJakub Kuderski       res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
195107677113SJakub Kuderski                                        op.getFastmathAttr());
195207677113SJakub Kuderski 
195307677113SJakub Kuderski     rewriter.replaceOp(op, res);
195407677113SJakub Kuderski     return success();
195507677113SJakub Kuderski   }
195607677113SJakub Kuderski 
195707677113SJakub Kuderski private:
195807677113SJakub Kuderski   unsigned maxNumElementsToExtract = 0;
195907677113SJakub Kuderski };
196007677113SJakub Kuderski 
19619f0aa05bSHugo Trachino /// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
19629f0aa05bSHugo Trachino /// B)`.
19639f0aa05bSHugo Trachino /// Example:
19649f0aa05bSHugo Trachino ///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
19659f0aa05bSHugo Trachino ///  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
19669f0aa05bSHugo Trachino ///  vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
19679f0aa05bSHugo Trachino ///  vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
19689f0aa05bSHugo Trachino ///
19699f0aa05bSHugo Trachino /// Becomes :
19709f0aa05bSHugo Trachino ///
19719f0aa05bSHugo Trachino ///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
19729f0aa05bSHugo Trachino ///
19739f0aa05bSHugo Trachino /// Supports only 1D-to-2D broadcasts. The following cases are not supported.
19749f0aa05bSHugo Trachino /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
19759f0aa05bSHugo Trachino /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
19769f0aa05bSHugo Trachino /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
19779f0aa05bSHugo Trachino template <typename MulOpType>
19789f0aa05bSHugo Trachino struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
19799f0aa05bSHugo Trachino   using OpRewritePattern<MulOpType>::OpRewritePattern;
19809f0aa05bSHugo Trachino   // Returns whether a vector.broadcast matches requirements for an outerproduct
19819f0aa05bSHugo Trachino   // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
19829f0aa05bSHugo Trachino   bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
19839f0aa05bSHugo Trachino     // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
19849f0aa05bSHugo Trachino     // shape_casts/broadcasts which does not belong in this pattern.
19859f0aa05bSHugo Trachino     if (!broadcastOp.computeBroadcastedUnitDims().empty())
19869f0aa05bSHugo Trachino       return false;
19879f0aa05bSHugo Trachino     // Avoid broadcast like f32 or vector<f32> -> ResType
19889f0aa05bSHugo Trachino     auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
19899f0aa05bSHugo Trachino     return srcType && srcType.getRank() != 2;
19909f0aa05bSHugo Trachino   }
19919f0aa05bSHugo Trachino 
19929f0aa05bSHugo Trachino   LogicalResult matchAndRewrite(MulOpType mulOp,
19939f0aa05bSHugo Trachino                                 PatternRewriter &rewriter) const override {
19949f0aa05bSHugo Trachino     auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
19959f0aa05bSHugo Trachino     if (!resType)
19969f0aa05bSHugo Trachino       return failure();
19979f0aa05bSHugo Trachino     if (resType.getRank() != 2)
19989f0aa05bSHugo Trachino       return failure();
19999f0aa05bSHugo Trachino     /// If operandA can be written as tr(broadcast(A)) and operandB as
20009f0aa05bSHugo Trachino     /// broadcast(B) where broadcasts are 1D-to-2D, create and return
20019f0aa05bSHugo Trachino     /// vector.outerproduct(A, B). Returns failure() otherwise.
20029f0aa05bSHugo Trachino     auto matchOuterProduct =
20039f0aa05bSHugo Trachino         [&](Value operandA,
20049f0aa05bSHugo Trachino             Value operandB) -> FailureOr<vector::OuterProductOp> {
20059f0aa05bSHugo Trachino       auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
20069f0aa05bSHugo Trachino       if (!transposedLhs)
20079f0aa05bSHugo Trachino         return failure();
20089f0aa05bSHugo Trachino       // Fail unless this is a true 2-D matrix transpose.
20099f0aa05bSHugo Trachino       ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
20109f0aa05bSHugo Trachino       if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
20119f0aa05bSHugo Trachino         return failure();
20129f0aa05bSHugo Trachino 
20139f0aa05bSHugo Trachino       auto broadcastedLhs =
20149f0aa05bSHugo Trachino           transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
20159f0aa05bSHugo Trachino       if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
20169f0aa05bSHugo Trachino         return failure();
20179f0aa05bSHugo Trachino 
20189f0aa05bSHugo Trachino       auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
20199f0aa05bSHugo Trachino       if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
20209f0aa05bSHugo Trachino         return failure();
20219f0aa05bSHugo Trachino 
20229f0aa05bSHugo Trachino       return rewriter.create<vector::OuterProductOp>(
20239f0aa05bSHugo Trachino           mulOp->getLoc(), resType, broadcastedLhs.getSource(),
20249f0aa05bSHugo Trachino           broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
20259f0aa05bSHugo Trachino     };
20269f0aa05bSHugo Trachino 
20279f0aa05bSHugo Trachino     Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
20289f0aa05bSHugo Trachino     auto maybeOuterP = matchOuterProduct(lhs, rhs);
20299f0aa05bSHugo Trachino     // Handle commutativity, the transposed op is the outerproduct LHS.
20309f0aa05bSHugo Trachino     if (failed(maybeOuterP))
20319f0aa05bSHugo Trachino       maybeOuterP = matchOuterProduct(rhs, lhs);
20329f0aa05bSHugo Trachino     if (failed(maybeOuterP))
20339f0aa05bSHugo Trachino       return failure();
20349f0aa05bSHugo Trachino     rewriter.replaceOp(mulOp, maybeOuterP->getResult());
20359f0aa05bSHugo Trachino     return success();
20369f0aa05bSHugo Trachino   }
20379f0aa05bSHugo Trachino };
20389f0aa05bSHugo Trachino 
203999ef9eebSMatthias Springer } // namespace
204099ef9eebSMatthias Springer 
20419a795f0cSManish Gupta void mlir::vector::populateFoldArithExtensionPatterns(
20429a795f0cSManish Gupta     RewritePatternSet &patterns) {
2043ac1e22f3SStanley Winata   patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
2044ac1e22f3SStanley Winata                FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
2045ac1e22f3SStanley Winata       patterns.getContext());
20469a795f0cSManish Gupta }
20479a795f0cSManish Gupta 
204899ef9eebSMatthias Springer void mlir::vector::populateVectorMaskMaterializationPatterns(
204927cc31b6SNicolas Vasilache     RewritePatternSet &patterns, bool force32BitVectorIndices,
205027cc31b6SNicolas Vasilache     PatternBenefit benefit) {
205199ef9eebSMatthias Springer   patterns.add<VectorCreateMaskOpConversion,
205299ef9eebSMatthias Springer                MaterializeTransferMask<vector::TransferReadOp>,
205399ef9eebSMatthias Springer                MaterializeTransferMask<vector::TransferWriteOp>>(
205427cc31b6SNicolas Vasilache       patterns.getContext(), force32BitVectorIndices, benefit);
205515a08cf2SDiego Caballero   patterns.add<FoldI1Select>(patterns.getContext(), benefit);
205699ef9eebSMatthias Springer }
205799ef9eebSMatthias Springer 
205827cc31b6SNicolas Vasilache void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
205927cc31b6SNicolas Vasilache                                                     PatternBenefit benefit) {
206027cc31b6SNicolas Vasilache   patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
206199ef9eebSMatthias Springer }
206299ef9eebSMatthias Springer 
2063c02d07fdSAndrzej Warzyński void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2064c02d07fdSAndrzej Warzyński     RewritePatternSet &patterns, PatternBenefit benefit) {
20651f5e8263SAndrzej Warzyński   // TODO: Consider either:
20661f5e8263SAndrzej Warzyński   //  * including DropInnerMostUnitDimsTransferRead and
20671f5e8263SAndrzej Warzyński   //    DropInnerMostUnitDimsTransferWrite, or
20681f5e8263SAndrzej Warzyński   //  * better naming to distinguish this and
20691f5e8263SAndrzej Warzyński   //    populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
20701f5e8263SAndrzej Warzyński   patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
20711f5e8263SAndrzej Warzyński                DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
2072a3b34e67SQuinn Dawkins       patterns.getContext(), benefit);
2073c02d07fdSAndrzej Warzyński }
2074c02d07fdSAndrzej Warzyński 
207599ef9eebSMatthias Springer void mlir::vector::populateBubbleVectorBitCastOpPatterns(
207627cc31b6SNicolas Vasilache     RewritePatternSet &patterns, PatternBenefit benefit) {
207799ef9eebSMatthias Springer   patterns.add<BubbleDownVectorBitCastForExtract,
207899ef9eebSMatthias Springer                BubbleDownBitCastForStridedSliceExtract,
20794623c114SDiego Caballero                BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
20804623c114SDiego Caballero       patterns.getContext(), benefit);
208199ef9eebSMatthias Springer }
208299ef9eebSMatthias Springer 
2083650f04feSQuinn Dawkins void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2084650f04feSQuinn Dawkins     RewritePatternSet &patterns,
2085650f04feSQuinn Dawkins     std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2086650f04feSQuinn Dawkins   patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2087650f04feSQuinn Dawkins                                        std::move(controlFn), benefit);
2088650f04feSQuinn Dawkins }
2089650f04feSQuinn Dawkins 
2090fb7ef637SJakub Kuderski void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
2091fb7ef637SJakub Kuderski     RewritePatternSet &patterns,
2092fb7ef637SJakub Kuderski     std::function<LogicalResult(vector::ContractionOp)> constraint,
2093fb7ef637SJakub Kuderski     PatternBenefit benefit) {
2094fb7ef637SJakub Kuderski   patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
2095fb7ef637SJakub Kuderski                                                 std::move(constraint));
2096fb7ef637SJakub Kuderski }
2097fb7ef637SJakub Kuderski 
209899ef9eebSMatthias Springer void mlir::vector::populateVectorReductionToContractPatterns(
209927cc31b6SNicolas Vasilache     RewritePatternSet &patterns, PatternBenefit benefit) {
210099ef9eebSMatthias Springer   patterns.add<MultiReduceToContract, CombineContractBroadcast,
210142944da5SAndrzej Warzyński                CombineContractABTranspose, CombineContractResultTranspose>(
2102f0c93fd4SLei Zhang       patterns.getContext(), benefit);
210399ef9eebSMatthias Springer }
210499ef9eebSMatthias Springer 
210599ef9eebSMatthias Springer void mlir::vector::
210699ef9eebSMatthias Springer     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
210727cc31b6SNicolas Vasilache         RewritePatternSet &patterns, PatternBenefit benefit) {
210812b676deSHan-Chung Wang   patterns.add<DropInnerMostUnitDimsTransferRead,
210912b676deSHan-Chung Wang                DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
211012b676deSHan-Chung Wang                                                    benefit);
211199ef9eebSMatthias Springer }
211299ef9eebSMatthias Springer 
211342944da5SAndrzej Warzyński void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
211442944da5SAndrzej Warzyński                                                  PatternBenefit benefit) {
211542944da5SAndrzej Warzyński   patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
211642944da5SAndrzej Warzyński                ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
211742944da5SAndrzej Warzyński                                                  benefit);
21184d339ec9SAndrzej Warzynski }
21194d339ec9SAndrzej Warzynski 
2120d33bad66SJakub Kuderski void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2121d33bad66SJakub Kuderski     RewritePatternSet &patterns, PatternBenefit benefit) {
2122d33bad66SJakub Kuderski   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
2123d33bad66SJakub Kuderski   patterns.add<ReduceRedundantZero>(patterns.getContext(),
2124d33bad66SJakub Kuderski                                     PatternBenefit(benefit.getBenefit() + 1));
2125d33bad66SJakub Kuderski }
2126d33bad66SJakub Kuderski 
212707677113SJakub Kuderski void mlir::vector::populateBreakDownVectorReductionPatterns(
212807677113SJakub Kuderski     RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
212907677113SJakub Kuderski     PatternBenefit benefit) {
213007677113SJakub Kuderski   patterns.add<BreakDownVectorReduction>(patterns.getContext(),
213107677113SJakub Kuderski                                          maxNumElementsToExtract, benefit);
213207677113SJakub Kuderski }
213307677113SJakub Kuderski 
21349f0aa05bSHugo Trachino void mlir::vector::populateElementwiseToVectorOpsPatterns(
21359f0aa05bSHugo Trachino     RewritePatternSet &patterns) {
21369f0aa05bSHugo Trachino   patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
21379f0aa05bSHugo Trachino                FoldArithToVectorOuterProduct<arith::MulIOp>>(
21389f0aa05bSHugo Trachino       patterns.getContext());
21399f0aa05bSHugo Trachino }
21409f0aa05bSHugo Trachino 
2141edec4239SQuentin Colombet //===----------------------------------------------------------------------===//
2142edec4239SQuentin Colombet // TableGen'd enum attribute definitions
2143edec4239SQuentin Colombet //===----------------------------------------------------------------------===//
2144edec4239SQuentin Colombet 
2145edec4239SQuentin Colombet #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
2146