199ef9eebSMatthias Springer //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===// 299ef9eebSMatthias Springer // 399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 699ef9eebSMatthias Springer // 799ef9eebSMatthias Springer //===----------------------------------------------------------------------===// 899ef9eebSMatthias Springer // 999ef9eebSMatthias Springer // This file implements target-independent rewrites as 1->N patterns. 1099ef9eebSMatthias Springer // 1199ef9eebSMatthias Springer //===----------------------------------------------------------------------===// 1299ef9eebSMatthias Springer 139b5a3d14SMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 149b5a3d14SMatthias Springer 1507677113SJakub Kuderski #include <cassert> 16f80a976aSJakub Kuderski #include <cstdint> 17fb7ef637SJakub Kuderski #include <functional> 185382d288SKazu Hirata #include <optional> 1999ef9eebSMatthias Springer #include <type_traits> 2099ef9eebSMatthias Springer 2199ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 22abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 23abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h" 2499ef9eebSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h" 2599ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 268b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 27f80a976aSJakub Kuderski #include "mlir/Dialect/Tensor/IR/Tensor.h" 28f71f9958SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h" 2999ef9eebSMatthias Springer #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 304758e916SOleg Shyshkov #include "mlir/Dialect/Vector/IR/VectorOps.h" 31fb7ef637SJakub Kuderski #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 329b5a3d14SMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 3383df43f3SAlex Zinenko #include "mlir/IR/BuiltinAttributeInterfaces.h" 344db65e27SLei Zhang #include "mlir/IR/BuiltinTypes.h" 3599ef9eebSMatthias Springer #include "mlir/IR/ImplicitLocOpBuilder.h" 36f80a976aSJakub Kuderski #include "mlir/IR/Location.h" 3799ef9eebSMatthias Springer #include "mlir/IR/Matchers.h" 3899ef9eebSMatthias Springer #include "mlir/IR/PatternMatch.h" 39f80a976aSJakub Kuderski #include "mlir/IR/TypeUtilities.h" 4099ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h" 4199ef9eebSMatthias Springer 4299ef9eebSMatthias Springer #include "llvm/ADT/DenseSet.h" 4399ef9eebSMatthias Springer #include "llvm/ADT/MapVector.h" 4499ef9eebSMatthias Springer #include "llvm/ADT/STLExtras.h" 4599ef9eebSMatthias Springer #include "llvm/Support/CommandLine.h" 4699ef9eebSMatthias Springer #include "llvm/Support/Debug.h" 4707677113SJakub Kuderski #include "llvm/Support/FormatVariadic.h" 4899ef9eebSMatthias Springer #include "llvm/Support/raw_ostream.h" 4999ef9eebSMatthias Springer 5099ef9eebSMatthias Springer #define DEBUG_TYPE "vector-to-vector" 5199ef9eebSMatthias Springer 5299ef9eebSMatthias Springer using namespace mlir; 5399ef9eebSMatthias Springer using namespace mlir::vector; 5499ef9eebSMatthias Springer 5599ef9eebSMatthias Springer template <typename IntType> 567a69a9d7SNicolas Vasilache static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) { 5799ef9eebSMatthias Springer return llvm::to_vector<4>(llvm::map_range( 5899ef9eebSMatthias Springer arrayAttr.getAsRange<IntegerAttr>(), 5999ef9eebSMatthias Springer [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); 6099ef9eebSMatthias Springer } 6199ef9eebSMatthias Springer 622bc4c3e9SNicolas Vasilache // Helper to find an index in an affine map. 632bc4c3e9SNicolas Vasilache static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) { 642bc4c3e9SNicolas Vasilache for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 652bc4c3e9SNicolas Vasilache int64_t idx = map.getDimPosition(i); 662bc4c3e9SNicolas Vasilache if (idx == index) 6789aaa2d0SThomas Raoux return i; 6889aaa2d0SThomas Raoux } 691a36588eSKazu Hirata return std::nullopt; 7089aaa2d0SThomas Raoux } 7189aaa2d0SThomas Raoux 7299ef9eebSMatthias Springer namespace { 7399ef9eebSMatthias Springer 7499ef9eebSMatthias Springer /// ShapeCastOpFolder folds cancelling ShapeCastOps away. 7599ef9eebSMatthias Springer // 7699ef9eebSMatthias Springer // Example: 7799ef9eebSMatthias Springer // 7899ef9eebSMatthias Springer // The following MLIR with cancelling ShapeCastOps: 7999ef9eebSMatthias Springer // 8099ef9eebSMatthias Springer // %0 = source : vector<5x4x2xf32> 8199ef9eebSMatthias Springer // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32> 8299ef9eebSMatthias Springer // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32> 8399ef9eebSMatthias Springer // %3 = user %2 : vector<5x4x2xf32> 8499ef9eebSMatthias Springer // 8599ef9eebSMatthias Springer // Should canonicalize to the following: 8699ef9eebSMatthias Springer // 8799ef9eebSMatthias Springer // %0 = source : vector<5x4x2xf32> 8899ef9eebSMatthias Springer // %1 = user %0 : vector<5x4x2xf32> 8999ef9eebSMatthias Springer // 9099ef9eebSMatthias Springer struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> { 9127cc31b6SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 9299ef9eebSMatthias Springer 9399ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, 9499ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 9599ef9eebSMatthias Springer // Check if 'shapeCastOp' has vector source/result type. 9699ef9eebSMatthias Springer auto sourceVectorType = 975550c821STres Popp dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType()); 9899ef9eebSMatthias Springer auto resultVectorType = 995550c821STres Popp dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType()); 10099ef9eebSMatthias Springer if (!sourceVectorType || !resultVectorType) 10199ef9eebSMatthias Springer return failure(); 10299ef9eebSMatthias Springer 10399ef9eebSMatthias Springer // Check if shape cast op source operand is also a shape cast op. 10499ef9eebSMatthias Springer auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>( 1057c38fd60SJacques Pienaar shapeCastOp.getSource().getDefiningOp()); 10699ef9eebSMatthias Springer if (!sourceShapeCastOp) 10799ef9eebSMatthias Springer return failure(); 10899ef9eebSMatthias Springer auto operandSourceVectorType = 1095550c821STres Popp cast<VectorType>(sourceShapeCastOp.getSource().getType()); 11099ef9eebSMatthias Springer auto operandResultVectorType = sourceShapeCastOp.getType(); 11199ef9eebSMatthias Springer 11299ef9eebSMatthias Springer // Check if shape cast operations invert each other. 11399ef9eebSMatthias Springer if (operandSourceVectorType != resultVectorType || 11499ef9eebSMatthias Springer operandResultVectorType != sourceVectorType) 11599ef9eebSMatthias Springer return failure(); 11699ef9eebSMatthias Springer 1177c38fd60SJacques Pienaar rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource()); 11899ef9eebSMatthias Springer return success(); 11999ef9eebSMatthias Springer } 12099ef9eebSMatthias Springer }; 12199ef9eebSMatthias Springer 12299ef9eebSMatthias Springer /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp. 12399ef9eebSMatthias Springer /// Ex: 12499ef9eebSMatthias Springer /// ``` 12599ef9eebSMatthias Springer /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32> 12699ef9eebSMatthias Springer /// %1 = vector.multi_reduction add, %0 [1] 12799ef9eebSMatthias Springer /// : vector<8x32x16xf32> to vector<8x16xf32> 12899ef9eebSMatthias Springer /// ``` 12999ef9eebSMatthias Springer /// Gets converted to: 13099ef9eebSMatthias Springer /// ``` 13199ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [ 13299ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 13399ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 13499ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>], 13599ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"], 13699ef9eebSMatthias Springer /// kind = add} %0, %arg1, %cst_f0 13799ef9eebSMatthias Springer /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> 13899ef9eebSMatthias Springer /// ``` 13999ef9eebSMatthias Springer struct MultiReduceToContract 14099ef9eebSMatthias Springer : public OpRewritePattern<vector::MultiDimReductionOp> { 14127cc31b6SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 14299ef9eebSMatthias Springer 14399ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, 14499ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 1457c38fd60SJacques Pienaar if (reduceOp.getKind() != vector::CombiningKind::ADD) 14699ef9eebSMatthias Springer return failure(); 1477c38fd60SJacques Pienaar Operation *mulOp = reduceOp.getSource().getDefiningOp(); 14899ef9eebSMatthias Springer if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp)) 14999ef9eebSMatthias Springer return failure(); 15099ef9eebSMatthias Springer SmallVector<bool> reductionMask = reduceOp.getReductionMask(); 15199ef9eebSMatthias Springer auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size()); 15299ef9eebSMatthias Springer SmallVector<AffineExpr> exprs; 1534758e916SOleg Shyshkov SmallVector<vector::IteratorType> iteratorTypes; 15499ef9eebSMatthias Springer for (const auto &isReduceDim : llvm::enumerate(reductionMask)) { 15599ef9eebSMatthias Springer if (!isReduceDim.value()) { 1564758e916SOleg Shyshkov iteratorTypes.push_back(vector::IteratorType::parallel); 15799ef9eebSMatthias Springer exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index())); 15899ef9eebSMatthias Springer } else { 1594758e916SOleg Shyshkov iteratorTypes.push_back(vector::IteratorType::reduction); 16099ef9eebSMatthias Springer } 16199ef9eebSMatthias Springer } 162fe8a62c4SUday Bondhugula auto dstMap = 163fe8a62c4SUday Bondhugula AffineMap::get(/*dimCount=*/reductionMask.size(), 164fe8a62c4SUday Bondhugula /*symbolCount=*/0, exprs, reduceOp.getContext()); 16599ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>( 166051b36baSThomas Raoux reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(), 16799ef9eebSMatthias Springer rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}), 1684758e916SOleg Shyshkov rewriter.getArrayAttr(llvm::to_vector(llvm::map_range( 1694758e916SOleg Shyshkov iteratorTypes, [&](IteratorType t) -> mlir::Attribute { 1704758e916SOleg Shyshkov return IteratorTypeAttr::get(rewriter.getContext(), t); 1714758e916SOleg Shyshkov })))); 17299ef9eebSMatthias Springer return success(); 17399ef9eebSMatthias Springer } 17499ef9eebSMatthias Springer }; 17599ef9eebSMatthias Springer 176f0c93fd4SLei Zhang /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user. 17799ef9eebSMatthias Springer /// Ex: 17899ef9eebSMatthias Springer /// ``` 17999ef9eebSMatthias Springer /// %0 = vector.transpose %arg0, [2, 0, 1] 18099ef9eebSMatthias Springer /// : vector<32x16x8xf32> to vector<8x32x16xf32> 18199ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [ 18299ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 18399ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 18499ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>], 18599ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"], 18699ef9eebSMatthias Springer /// kind = add} %0, %arg1, %cst_f0 18799ef9eebSMatthias Springer /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> 18899ef9eebSMatthias Springer /// ``` 18999ef9eebSMatthias Springer /// Gets converted to: 19099ef9eebSMatthias Springer /// ``` 19199ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [ 19299ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>, 19399ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 19499ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>], 19599ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"], 19699ef9eebSMatthias Springer /// kind = add} %arg0, %arg1, %cst_f0 19799ef9eebSMatthias Springer /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> 19899ef9eebSMatthias Springer /// ``` 199f0c93fd4SLei Zhang struct CombineContractABTranspose final 20099ef9eebSMatthias Springer : public OpRewritePattern<vector::ContractionOp> { 20127cc31b6SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 20299ef9eebSMatthias Springer 20399ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 20499ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 2057a69a9d7SNicolas Vasilache SmallVector<AffineMap> maps = 206d2c0572bSJacques Pienaar llvm::to_vector<4>(contractOp.getIndexingMapsArray()); 2077c38fd60SJacques Pienaar Value lhs = contractOp.getLhs(); 2087c38fd60SJacques Pienaar Value rhs = contractOp.getRhs(); 20999ef9eebSMatthias Springer size_t index = 0; 21099ef9eebSMatthias Springer bool changed = false; 21199ef9eebSMatthias Springer for (Value *operand : {&lhs, &rhs}) { 21299ef9eebSMatthias Springer AffineMap &map = maps[index++]; 21399ef9eebSMatthias Springer auto transposeOp = operand->getDefiningOp<vector::TransposeOp>(); 21499ef9eebSMatthias Springer if (!transposeOp) 21599ef9eebSMatthias Springer continue; 21699ef9eebSMatthias Springer AffineMap permutationMap = AffineMap::getPermutationMap( 21732c3decbSMatthias Springer transposeOp.getPermutation(), contractOp.getContext()); 21899ef9eebSMatthias Springer map = inversePermutation(permutationMap).compose(map); 2197c38fd60SJacques Pienaar *operand = transposeOp.getVector(); 22099ef9eebSMatthias Springer changed = true; 22199ef9eebSMatthias Springer } 22299ef9eebSMatthias Springer if (!changed) 22399ef9eebSMatthias Springer return failure(); 22499ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ContractionOp>( 2257c38fd60SJacques Pienaar contractOp, lhs, rhs, contractOp.getAcc(), 2267c38fd60SJacques Pienaar rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes()); 22799ef9eebSMatthias Springer return success(); 22899ef9eebSMatthias Springer } 22999ef9eebSMatthias Springer }; 23099ef9eebSMatthias Springer 231f0c93fd4SLei Zhang /// Merges accumulator and result transposes into contract. 232f0c93fd4SLei Zhang /// 233f0c93fd4SLei Zhang /// For example: 234f0c93fd4SLei Zhang /// ```mlir 235f0c93fd4SLei Zhang /// %accT = vector.transpose %acc, [0, 2, 1] 236f0c93fd4SLei Zhang /// : vector<2x8x4xf32> to vector<2x4x8xf32> 237f0c93fd4SLei Zhang /// %contract = vector.contract { 238f0c93fd4SLei Zhang /// indexing_maps = [ 239f0c93fd4SLei Zhang /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, 240f0c93fd4SLei Zhang /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>, 241f0c93fd4SLei Zhang /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> 242f0c93fd4SLei Zhang /// ], 243f0c93fd4SLei Zhang /// iterator_types = ["parallel", "parallel", "parallel", "reduction"], 244f0c93fd4SLei Zhang /// kind = #vector.kind<add> 245f0c93fd4SLei Zhang /// } %lhs, %rhs, %accT 246f0c93fd4SLei Zhang /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32> 247f0c93fd4SLei Zhang /// %0 = vector.transpose %contract, [0, 2, 1] 248f0c93fd4SLei Zhang /// : vector<2x4x8xf32> to vector<2x8x4> 249f0c93fd4SLei Zhang /// ``` 250f0c93fd4SLei Zhang /// Becomes: 251f0c93fd4SLei Zhang /// ```mlir 252f0c93fd4SLei Zhang /// %0 = vector.contract { 253f0c93fd4SLei Zhang /// indexing_maps = [ 254f0c93fd4SLei Zhang /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, 255f0c93fd4SLei Zhang /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>, 256f0c93fd4SLei Zhang /// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> 257f0c93fd4SLei Zhang /// ], 258f0c93fd4SLei Zhang /// iterator_types = ["parallel", "parallel", "parallel", "reduction"], 259f0c93fd4SLei Zhang /// kind = #vector.kind<add> 260f0c93fd4SLei Zhang /// } %lhs, %rhs, %acc 261f0c93fd4SLei Zhang /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32> 262f0c93fd4SLei Zhang /// ``` 263f0c93fd4SLei Zhang struct CombineContractResultTranspose final 264f0c93fd4SLei Zhang : public OpRewritePattern<vector::TransposeOp> { 265f0c93fd4SLei Zhang using OpRewritePattern::OpRewritePattern; 266f0c93fd4SLei Zhang 267f0c93fd4SLei Zhang LogicalResult matchAndRewrite(vector::TransposeOp resTOp, 268f0c93fd4SLei Zhang PatternRewriter &rewriter) const override { 269f0c93fd4SLei Zhang auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>(); 270f0c93fd4SLei Zhang if (!contractOp || !contractOp->hasOneUse()) 271f0c93fd4SLei Zhang return failure(); 272f0c93fd4SLei Zhang 273f0c93fd4SLei Zhang auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>(); 274f0c93fd4SLei Zhang if (!accTOp) 275f0c93fd4SLei Zhang return failure(); 276f0c93fd4SLei Zhang 277f0c93fd4SLei Zhang MLIRContext *context = contractOp.getContext(); 278f0c93fd4SLei Zhang auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray()); 279f0c93fd4SLei Zhang AffineMap contractMap = maps.back(); 280f0c93fd4SLei Zhang 281f0c93fd4SLei Zhang // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B. 282f0c93fd4SLei Zhang // To index into A in contract, we need revert(f)(g(C)) -> A. 28332c3decbSMatthias Springer auto accTMap = 28432c3decbSMatthias Springer AffineMap::getPermutationMap(accTOp.getPermutation(), context); 285f0c93fd4SLei Zhang 286f0c93fd4SLei Zhang // Contract performs g(C) -> D. Result transpose performs h(D) -> E. 287f0c93fd4SLei Zhang // To index into E in contract, we need h(g(C)) -> E. 28832c3decbSMatthias Springer auto resTMap = 28932c3decbSMatthias Springer AffineMap::getPermutationMap(resTOp.getPermutation(), context); 290f0c93fd4SLei Zhang auto combinedResMap = resTMap.compose(contractMap); 291f0c93fd4SLei Zhang 292f0c93fd4SLei Zhang // The accumulator and result share the same indexing map. So they should be 293f0c93fd4SLei Zhang // the same to be able to merge. This means combinedResMap is the same as 294f0c93fd4SLei Zhang // inversePermutation(accTMap).compose(contractMap), which means 295f0c93fd4SLei Zhang if (inversePermutation(accTMap) != resTMap) 296f0c93fd4SLei Zhang return failure(); 297f0c93fd4SLei Zhang maps.back() = combinedResMap; 298f0c93fd4SLei Zhang 299f0c93fd4SLei Zhang rewriter.replaceOpWithNewOp<vector::ContractionOp>( 300f0c93fd4SLei Zhang resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(), 301f0c93fd4SLei Zhang rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes()); 302f0c93fd4SLei Zhang return success(); 303f0c93fd4SLei Zhang } 304f0c93fd4SLei Zhang }; 305f0c93fd4SLei Zhang 30699ef9eebSMatthias Springer /// Merge BroadcastOp into ContractionOp user. 30799ef9eebSMatthias Springer /// Ex: 30899ef9eebSMatthias Springer /// ``` 30999ef9eebSMatthias Springer /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32> 31099ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [ 31199ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 31299ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 31399ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>], 31499ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"], 31599ef9eebSMatthias Springer /// kind = add} %0, %arg1, %cst_f0 31699ef9eebSMatthias Springer /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> 31799ef9eebSMatthias Springer /// ``` 31899ef9eebSMatthias Springer /// Gets converted to: 31999ef9eebSMatthias Springer /// ``` 32099ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [ 32199ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d1, d2)>, 32299ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 32399ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>], 32499ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"], 32599ef9eebSMatthias Springer /// kind = add} %arg0, %arg1, %cst_f0 32699ef9eebSMatthias Springer /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> 32799ef9eebSMatthias Springer /// ``` 32899ef9eebSMatthias Springer struct CombineContractBroadcast 32999ef9eebSMatthias Springer : public OpRewritePattern<vector::ContractionOp> { 33027cc31b6SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 33199ef9eebSMatthias Springer 33299ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 33399ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 3347a69a9d7SNicolas Vasilache SmallVector<AffineMap> maps = 335d2c0572bSJacques Pienaar llvm::to_vector<4>(contractOp.getIndexingMapsArray()); 3367c38fd60SJacques Pienaar Value lhs = contractOp.getLhs(); 3377c38fd60SJacques Pienaar Value rhs = contractOp.getRhs(); 33899ef9eebSMatthias Springer size_t index = 0; 33999ef9eebSMatthias Springer bool changed = false; 34099ef9eebSMatthias Springer for (Value *operand : {&lhs, &rhs}) { 34199ef9eebSMatthias Springer AffineMap &map = maps[index++]; 34299ef9eebSMatthias Springer auto broadcast = operand->getDefiningOp<vector::BroadcastOp>(); 34399ef9eebSMatthias Springer if (!broadcast) 34499ef9eebSMatthias Springer continue; 34599ef9eebSMatthias Springer // contractionOp can only take vector as operands. 3465550c821STres Popp auto srcType = dyn_cast<VectorType>(broadcast.getSourceType()); 347a1aad28dSLei Zhang if (!srcType || 348a1aad28dSLei Zhang srcType.getRank() == broadcast.getResultVectorType().getRank()) 34999ef9eebSMatthias Springer continue; 35099ef9eebSMatthias Springer int64_t rankDiff = 351a1aad28dSLei Zhang broadcast.getResultVectorType().getRank() - srcType.getRank(); 35299ef9eebSMatthias Springer bool innerDimBroadcast = false; 35399ef9eebSMatthias Springer SmallVector<AffineExpr> originalDims; 35499ef9eebSMatthias Springer for (const auto &dim : llvm::enumerate(srcType.getShape())) { 355a1aad28dSLei Zhang if (dim.value() != broadcast.getResultVectorType().getDimSize( 356a1aad28dSLei Zhang rankDiff + dim.index())) { 35799ef9eebSMatthias Springer innerDimBroadcast = true; 35899ef9eebSMatthias Springer break; 35999ef9eebSMatthias Springer } 36099ef9eebSMatthias Springer originalDims.push_back( 36199ef9eebSMatthias Springer rewriter.getAffineDimExpr(dim.index() + rankDiff)); 36299ef9eebSMatthias Springer } 36399ef9eebSMatthias Springer // Contract doesn't support inner dimension broadcast. Once this is 36499ef9eebSMatthias Springer // relaxed we can remove this case. 36599ef9eebSMatthias Springer if (innerDimBroadcast) 36699ef9eebSMatthias Springer continue; 367694ad3eaSBenoit Jacob 368694ad3eaSBenoit Jacob // It would be incorrect to fold a broadcast onto a reduction dimension 369694ad3eaSBenoit Jacob // of non-unit size. 370694ad3eaSBenoit Jacob bool nonUnitDimReductionBroadcast = false; 371694ad3eaSBenoit Jacob for (int64_t i = 0; i < rankDiff; ++i) { 372a1aad28dSLei Zhang if (broadcast.getResultVectorType().getDimSize(i) != 1 && 373694ad3eaSBenoit Jacob isReductionIterator(contractOp.getIteratorTypes() 374694ad3eaSBenoit Jacob .getValue()[map.getDimPosition(i)])) { 375694ad3eaSBenoit Jacob nonUnitDimReductionBroadcast = true; 376694ad3eaSBenoit Jacob break; 377694ad3eaSBenoit Jacob } 378694ad3eaSBenoit Jacob } 379694ad3eaSBenoit Jacob if (nonUnitDimReductionBroadcast) 380694ad3eaSBenoit Jacob continue; 381694ad3eaSBenoit Jacob 38299ef9eebSMatthias Springer AffineMap broadcastMap = 383a1aad28dSLei Zhang AffineMap::get(broadcast.getResultVectorType().getRank(), 0, 384a1aad28dSLei Zhang originalDims, contractOp.getContext()); 38599ef9eebSMatthias Springer map = broadcastMap.compose(map); 3867c38fd60SJacques Pienaar *operand = broadcast.getSource(); 38799ef9eebSMatthias Springer changed = true; 38899ef9eebSMatthias Springer } 389694ad3eaSBenoit Jacob 39099ef9eebSMatthias Springer if (!changed) 39199ef9eebSMatthias Springer return failure(); 392694ad3eaSBenoit Jacob 393694ad3eaSBenoit Jacob // Determine which dims are usused, now that the maps have been composed 394694ad3eaSBenoit Jacob // with the broadcast maps. 395c3839c0bSBenoit Jacob llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps); 396694ad3eaSBenoit Jacob // Compress unused dims. 397694ad3eaSBenoit Jacob for (auto &m : maps) 398c3839c0bSBenoit Jacob m = compressDims(m, unusedDimsBitVector); 399694ad3eaSBenoit Jacob // Compute the combined iterators. 4007a69a9d7SNicolas Vasilache SmallVector<Attribute> iterators; 401c3839c0bSBenoit Jacob for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) { 402c3839c0bSBenoit Jacob if (!unusedDimsBitVector.test(i)) 403694ad3eaSBenoit Jacob iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); 404694ad3eaSBenoit Jacob } 405f0c3fd18SBenoit Jacob // Check that compressing unused dims isn't removing all reduction dimension 406f0c3fd18SBenoit Jacob // pairs. For example, if the vector.contract had only one reduction 407694ad3eaSBenoit Jacob // iterator and that was a unit-dimension created by a broadcast, 408694ad3eaSBenoit Jacob // then we should bail here, otherwise we would create a contract without 409f0c3fd18SBenoit Jacob // a reduction dimension pair. 410f0c3fd18SBenoit Jacob bool hasReductionIteratorApplyingOnBothSides = false; 411f0c3fd18SBenoit Jacob for (unsigned i = 0; i < iterators.size(); ++i) { 412f0c3fd18SBenoit Jacob if (!isReductionIterator(iterators[i])) 413f0c3fd18SBenoit Jacob continue; 414f0c3fd18SBenoit Jacob if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) { 415f0c3fd18SBenoit Jacob hasReductionIteratorApplyingOnBothSides = true; 416f0c3fd18SBenoit Jacob break; 417f0c3fd18SBenoit Jacob } 418f0c3fd18SBenoit Jacob } 419f0c3fd18SBenoit Jacob if (!hasReductionIteratorApplyingOnBothSides) 420694ad3eaSBenoit Jacob return failure(); 421f0c3fd18SBenoit Jacob 422c3839c0bSBenoit Jacob // If the compressed maps have a dimension that is not used by either LHS or 423c3839c0bSBenoit Jacob // RHS then the ContractionOp verifier would fail. 424c3839c0bSBenoit Jacob if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) 425c3839c0bSBenoit Jacob return failure(); 42699ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ContractionOp>( 4277c38fd60SJacques Pienaar contractOp, lhs, rhs, contractOp.getAcc(), 428694ad3eaSBenoit Jacob rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); 42999ef9eebSMatthias Springer return success(); 43099ef9eebSMatthias Springer } 43199ef9eebSMatthias Springer }; 43299ef9eebSMatthias Springer 4331538bd51SHanhan Wang /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and 4341538bd51SHanhan Wang /// contraction ops closer, which kicks in CombineContractBroadcast pattern when 4351538bd51SHanhan Wang /// casting ops are around these operations. 4361538bd51SHanhan Wang /// Ex: 4371538bd51SHanhan Wang /// ``` 4381538bd51SHanhan Wang /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8> 4391538bd51SHanhan Wang /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32> 4401538bd51SHanhan Wang /// ``` 4411538bd51SHanhan Wang /// Gets converted to: 4421538bd51SHanhan Wang /// ``` 4431538bd51SHanhan Wang /// %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32> 4441538bd51SHanhan Wang /// %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32> 4451538bd51SHanhan Wang /// ``` 4461538bd51SHanhan Wang struct ReorderCastOpsOnBroadcast 4471538bd51SHanhan Wang : public OpInterfaceRewritePattern<CastOpInterface> { 4481538bd51SHanhan Wang using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern; 4491538bd51SHanhan Wang 4501538bd51SHanhan Wang LogicalResult matchAndRewrite(CastOpInterface op, 4511538bd51SHanhan Wang PatternRewriter &rewriter) const override { 4521538bd51SHanhan Wang if (op->getNumOperands() != 1) 4531538bd51SHanhan Wang return failure(); 4541538bd51SHanhan Wang auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>(); 4551538bd51SHanhan Wang if (!bcastOp) 4561538bd51SHanhan Wang return failure(); 4571538bd51SHanhan Wang 4581538bd51SHanhan Wang Type castResTy = getElementTypeOrSelf(op->getResult(0)); 4595550c821STres Popp if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType())) 4604b2ba5a6SAndrzej Warzyński castResTy = vecTy.clone(castResTy); 46135f48edbSMehdi Amini auto *castOp = 4627c38fd60SJacques Pienaar rewriter.create(op->getLoc(), op->getName().getIdentifier(), 4637c38fd60SJacques Pienaar bcastOp.getSource(), castResTy, op->getAttrs()); 4641538bd51SHanhan Wang rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 4651538bd51SHanhan Wang op, op->getResult(0).getType(), castOp->getResult(0)); 4661538bd51SHanhan Wang return success(); 4671538bd51SHanhan Wang } 4681538bd51SHanhan Wang }; 4691538bd51SHanhan Wang 4704db65e27SLei Zhang /// Reorders elementwise(transpose) to transpose(elementwise). This makes 4714db65e27SLei Zhang /// transpose ops and contraction ops closer, which kicks in 472f0c93fd4SLei Zhang /// CombineContractABTranspose pattern when elementwise ops are between these 4734db65e27SLei Zhang /// operations. Ex: 4741538bd51SHanhan Wang /// ``` 4754db65e27SLei Zhang /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> 4764db65e27SLei Zhang /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> 4774db65e27SLei Zhang /// %r = arith.addf %at, %bt : vector<2x4xf32> 4781538bd51SHanhan Wang /// ``` 4791538bd51SHanhan Wang /// Gets converted to: 4801538bd51SHanhan Wang /// ``` 4814db65e27SLei Zhang /// %0 = arith.addf %a, %b : vector<4x2xf32> 4824db65e27SLei Zhang /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32> 4831538bd51SHanhan Wang /// ``` 4844db65e27SLei Zhang struct ReorderElementwiseOpsOnTranspose final 4854db65e27SLei Zhang : public OpTraitRewritePattern<OpTrait::Elementwise> { 4864db65e27SLei Zhang using OpTraitRewritePattern::OpTraitRewritePattern; 4874db65e27SLei Zhang LogicalResult matchAndRewrite(Operation *op, 4881538bd51SHanhan Wang PatternRewriter &rewriter) const override { 4894db65e27SLei Zhang if (op->getNumResults() != 1 || op->getNumRegions() != 0) 4901538bd51SHanhan Wang return failure(); 4911538bd51SHanhan Wang 4924db65e27SLei Zhang // Make sure all operands are transpose/constant ops and collect their 4934db65e27SLei Zhang // transposition maps. 49432c3decbSMatthias Springer SmallVector<ArrayRef<int64_t>> transposeMaps; 4954db65e27SLei Zhang transposeMaps.reserve(op->getNumOperands()); 4964db65e27SLei Zhang // Record the initial type before transposition. We'll use its shape later. 4974db65e27SLei Zhang // Any type will do here as we will check all transpose maps are the same. 4984db65e27SLei Zhang VectorType srcType; 4994db65e27SLei Zhang for (Value operand : op->getOperands()) { 5004db65e27SLei Zhang auto transposeOp = operand.getDefiningOp<vector::TransposeOp>(); 5014db65e27SLei Zhang if (transposeOp) { 50232c3decbSMatthias Springer transposeMaps.push_back(transposeOp.getPermutation()); 503a1aad28dSLei Zhang srcType = transposeOp.getSourceVectorType(); 5044db65e27SLei Zhang } else if (!matchPattern(operand, m_Constant())) { 5054db65e27SLei Zhang return failure(); 5064db65e27SLei Zhang } 5074db65e27SLei Zhang } 5084db65e27SLei Zhang if (transposeMaps.empty()) 5094db65e27SLei Zhang return failure(); 5104db65e27SLei Zhang // This is an elementwise op, so all transposed operands should have the 5114db65e27SLei Zhang // same type. We need to additionally check that all transposes uses the 5124db65e27SLei Zhang // same map. 5136fa87ec1SJakub Kuderski if (!llvm::all_equal(transposeMaps)) 5144db65e27SLei Zhang return rewriter.notifyMatchFailure(op, "different transpose map"); 5154db65e27SLei Zhang 5167a69a9d7SNicolas Vasilache SmallVector<Value> srcValues; 5174db65e27SLei Zhang srcValues.reserve(op->getNumOperands()); 5184db65e27SLei Zhang 5194db65e27SLei Zhang // If there are constant operands, we need to insert inverse transposes for 5204db65e27SLei Zhang // them. Calculate the inverse order first. 52132c3decbSMatthias Springer auto order = transposeMaps.front(); 5224db65e27SLei Zhang SmallVector<int64_t> invOrder(order.size()); 5234db65e27SLei Zhang for (int i = 0, e = order.size(); i < e; ++i) 5244db65e27SLei Zhang invOrder[order[i]] = i; 5254db65e27SLei Zhang 5264db65e27SLei Zhang for (Value operand : op->getOperands()) { 5274db65e27SLei Zhang auto transposeOp = operand.getDefiningOp<vector::TransposeOp>(); 5284db65e27SLei Zhang if (transposeOp) { 5294db65e27SLei Zhang srcValues.push_back(transposeOp.getVector()); 5304db65e27SLei Zhang } else { 5314db65e27SLei Zhang // This is a constant. Create a reverse transpose op for it. 5324b2ba5a6SAndrzej Warzyński auto vectorType = 5334b2ba5a6SAndrzej Warzyński srcType.clone(cast<VectorType>(operand.getType()).getElementType()); 5344db65e27SLei Zhang srcValues.push_back(rewriter.create<vector::TransposeOp>( 53532c3decbSMatthias Springer operand.getLoc(), vectorType, operand, invOrder)); 5364db65e27SLei Zhang } 5374db65e27SLei Zhang } 5384db65e27SLei Zhang 5394b2ba5a6SAndrzej Warzyński auto vectorType = srcType.clone( 5405550c821STres Popp cast<VectorType>(op->getResultTypes()[0]).getElementType()); 5414db65e27SLei Zhang Operation *elementwiseOp = 5424db65e27SLei Zhang rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, 5434db65e27SLei Zhang vectorType, op->getAttrs()); 5441538bd51SHanhan Wang rewriter.replaceOpWithNewOp<vector::TransposeOp>( 5454db65e27SLei Zhang op, op->getResultTypes()[0], elementwiseOp->getResult(0), 5464db65e27SLei Zhang transposeMaps.front()); 5471538bd51SHanhan Wang return success(); 5481538bd51SHanhan Wang } 5491538bd51SHanhan Wang }; 5501538bd51SHanhan Wang 55199ef9eebSMatthias Springer // Returns the values in `arrayAttr` as an integer vector. 5527a69a9d7SNicolas Vasilache static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) { 55399ef9eebSMatthias Springer return llvm::to_vector<4>( 55499ef9eebSMatthias Springer llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(), 55599ef9eebSMatthias Springer [](IntegerAttr attr) { return attr.getInt(); })); 55699ef9eebSMatthias Springer } 55799ef9eebSMatthias Springer 55899ef9eebSMatthias Springer // Shuffles vector.bitcast op after vector.extract op. 55999ef9eebSMatthias Springer // 56099ef9eebSMatthias Springer // This transforms IR like: 56199ef9eebSMatthias Springer // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> 5629816edc9SCullen Rhodes // %1 = vector.extract %0[3] : f16 from vector<8xf16> 56399ef9eebSMatthias Springer // Into: 5649816edc9SCullen Rhodes // %0 = vector.extract %src[1] : f32 from vector<4xf32> 56599ef9eebSMatthias Springer // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16> 5669816edc9SCullen Rhodes // %2 = vector.extract %1[1] : f16 from vector<2xf16> 56799ef9eebSMatthias Springer struct BubbleDownVectorBitCastForExtract 56899ef9eebSMatthias Springer : public OpRewritePattern<vector::ExtractOp> { 56999ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern; 57099ef9eebSMatthias Springer 57199ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 57299ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 57399ef9eebSMatthias Springer // Only support extracting scalars for now. 574a1aad28dSLei Zhang if (extractOp.getSourceVectorType().getRank() != 1) 57599ef9eebSMatthias Springer return failure(); 57699ef9eebSMatthias Springer 5777c38fd60SJacques Pienaar auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>(); 57899ef9eebSMatthias Springer if (!castOp) 57999ef9eebSMatthias Springer return failure(); 58099ef9eebSMatthias Springer 58199ef9eebSMatthias Springer VectorType castSrcType = castOp.getSourceVectorType(); 58299ef9eebSMatthias Springer VectorType castDstType = castOp.getResultVectorType(); 58399ef9eebSMatthias Springer assert(castSrcType.getRank() == castDstType.getRank()); 58499ef9eebSMatthias Springer 58599ef9eebSMatthias Springer // Fail to match if we only have one element in the cast op source. 58699ef9eebSMatthias Springer // This is to avoid infinite loop given that this pattern can generate 58799ef9eebSMatthias Springer // such cases. 58899ef9eebSMatthias Springer if (castSrcType.getNumElements() == 1) 58999ef9eebSMatthias Springer return failure(); 59099ef9eebSMatthias Springer 59199ef9eebSMatthias Springer // Only support casting to a larger number of elements or now. 59299ef9eebSMatthias Springer // E.g., vector<4xf32> -> vector<8xf16>. 59399ef9eebSMatthias Springer if (castSrcType.getNumElements() > castDstType.getNumElements()) 59499ef9eebSMatthias Springer return failure(); 59599ef9eebSMatthias Springer 59699ef9eebSMatthias Springer unsigned expandRatio = 59799ef9eebSMatthias Springer castDstType.getNumElements() / castSrcType.getNumElements(); 59899ef9eebSMatthias Springer 5996626ed6fSlialan // Get the first element of the mixed position as integer. 6006626ed6fSlialan auto mixedPos = extractOp.getMixedPosition(); 6016e41483bSKazu Hirata if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0])) 6026626ed6fSlialan return failure(); 6036e41483bSKazu Hirata uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt(); 60499ef9eebSMatthias Springer 60599ef9eebSMatthias Springer // Get the single scalar (as a vector) in the source value that packs the 60699ef9eebSMatthias Springer // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> 60798f6289aSDiego Caballero Location loc = extractOp.getLoc(); 60899ef9eebSMatthias Springer Value packedValue = rewriter.create<vector::ExtractOp>( 60998f6289aSDiego Caballero loc, castOp.getSource(), index / expandRatio); 61098f6289aSDiego Caballero Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType()); 61198f6289aSDiego Caballero Value zero = rewriter.create<arith::ConstantOp>( 61298f6289aSDiego Caballero loc, packedVecType, rewriter.getZeroAttr(packedVecType)); 61398f6289aSDiego Caballero packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero, 61498f6289aSDiego Caballero /*position=*/0); 61599ef9eebSMatthias Springer 61699ef9eebSMatthias Springer // Cast it to a vector with the desired scalar's type. 61799ef9eebSMatthias Springer // E.g. f32 -> vector<2xf16> 61899ef9eebSMatthias Springer VectorType packedType = 61999ef9eebSMatthias Springer VectorType::get({expandRatio}, castDstType.getElementType()); 62098f6289aSDiego Caballero Value castedValue = 62198f6289aSDiego Caballero rewriter.create<vector::BitCastOp>(loc, packedType, packedValue); 62299ef9eebSMatthias Springer 62399ef9eebSMatthias Springer // Finally extract the desired scalar. 62498f6289aSDiego Caballero rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue, 62598f6289aSDiego Caballero index % expandRatio); 62699ef9eebSMatthias Springer return success(); 62799ef9eebSMatthias Springer } 62899ef9eebSMatthias Springer }; 62999ef9eebSMatthias Springer 63099ef9eebSMatthias Springer // Shuffles vector.bitcast op after vector.extract_strided_slice op. 63199ef9eebSMatthias Springer // 63299ef9eebSMatthias Springer // This transforms IR like: 63399ef9eebSMatthias Springer // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> 63499ef9eebSMatthias Springer // %0 = vector.extract_strided_slice %cast { 63599ef9eebSMatthias Springer // offsets = [4], sizes = [4], strides = [1] 63699ef9eebSMatthias Springer // } : vector<8xf16> to vector<4xf16> 63799ef9eebSMatthias Springer // Into: 63899ef9eebSMatthias Springer // %0 = vector.extract_strided_slice %src { 63999ef9eebSMatthias Springer // offsets = [2], sizes = [2], strides = [1] 64099ef9eebSMatthias Springer // } : vector<4xf32> to vector<2xf32> 64199ef9eebSMatthias Springer // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 64299ef9eebSMatthias Springer struct BubbleDownBitCastForStridedSliceExtract 64399ef9eebSMatthias Springer : public OpRewritePattern<vector::ExtractStridedSliceOp> { 64499ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern; 64599ef9eebSMatthias Springer 64699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 64799ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 6487c38fd60SJacques Pienaar auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>(); 64999ef9eebSMatthias Springer if (!castOp) 65099ef9eebSMatthias Springer return failure(); 65199ef9eebSMatthias Springer 65299ef9eebSMatthias Springer VectorType castSrcType = castOp.getSourceVectorType(); 65399ef9eebSMatthias Springer VectorType castDstType = castOp.getResultVectorType(); 65499ef9eebSMatthias Springer assert(castSrcType.getRank() == castDstType.getRank()); 65599ef9eebSMatthias Springer 65699ef9eebSMatthias Springer int64_t castSrcLastDim = castSrcType.getShape().back(); 65799ef9eebSMatthias Springer int64_t castDstLastDim = castDstType.getShape().back(); 65899ef9eebSMatthias Springer // Require casting to more elements for now; other cases to be implemented. 65999ef9eebSMatthias Springer if (castSrcLastDim > castDstLastDim) 66099ef9eebSMatthias Springer return failure(); 66199ef9eebSMatthias Springer 66299ef9eebSMatthias Springer // Only accept all one strides for now. 6637c38fd60SJacques Pienaar if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(), 6649e5d2495SKazu Hirata [](const APInt &val) { return !val.isOne(); })) 66599ef9eebSMatthias Springer return failure(); 66699ef9eebSMatthias Springer 667a1aad28dSLei Zhang unsigned rank = extractOp.getSourceVectorType().getRank(); 66899ef9eebSMatthias Springer assert(castDstLastDim % castSrcLastDim == 0); 66999ef9eebSMatthias Springer int64_t expandRatio = castDstLastDim / castSrcLastDim; 67099ef9eebSMatthias Springer 67199ef9eebSMatthias Springer // If we have a less number of offsets than the rank, then implicitly we 67299ef9eebSMatthias Springer // are selecting the full range for the last bitcasted dimension; other 67399ef9eebSMatthias Springer // dimensions aren't affected. Otherwise, we need to scale down the last 67499ef9eebSMatthias Springer // dimension's offset given we are extracting from less elements now. 6757c38fd60SJacques Pienaar ArrayAttr newOffsets = extractOp.getOffsets(); 67699ef9eebSMatthias Springer if (newOffsets.size() == rank) { 6777a69a9d7SNicolas Vasilache SmallVector<int64_t> offsets = getIntValueVector(newOffsets); 67899ef9eebSMatthias Springer if (offsets.back() % expandRatio != 0) 67999ef9eebSMatthias Springer return failure(); 68099ef9eebSMatthias Springer offsets.back() = offsets.back() / expandRatio; 68199ef9eebSMatthias Springer newOffsets = rewriter.getI64ArrayAttr(offsets); 68299ef9eebSMatthias Springer } 68399ef9eebSMatthias Springer 68499ef9eebSMatthias Springer // Similarly for sizes. 6857c38fd60SJacques Pienaar ArrayAttr newSizes = extractOp.getSizes(); 68699ef9eebSMatthias Springer if (newSizes.size() == rank) { 6877a69a9d7SNicolas Vasilache SmallVector<int64_t> sizes = getIntValueVector(newSizes); 68899ef9eebSMatthias Springer if (sizes.back() % expandRatio != 0) 68999ef9eebSMatthias Springer return failure(); 69099ef9eebSMatthias Springer sizes.back() = sizes.back() / expandRatio; 69199ef9eebSMatthias Springer newSizes = rewriter.getI64ArrayAttr(sizes); 69299ef9eebSMatthias Springer } 69399ef9eebSMatthias Springer 6947a69a9d7SNicolas Vasilache SmallVector<int64_t> dims = 6955550c821STres Popp llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape()); 69699ef9eebSMatthias Springer dims.back() = dims.back() / expandRatio; 69799ef9eebSMatthias Springer VectorType newExtractType = 69899ef9eebSMatthias Springer VectorType::get(dims, castSrcType.getElementType()); 69999ef9eebSMatthias Springer 70099ef9eebSMatthias Springer auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 7017c38fd60SJacques Pienaar extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets, 7027c38fd60SJacques Pienaar newSizes, extractOp.getStrides()); 70399ef9eebSMatthias Springer 70499ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BitCastOp>( 70599ef9eebSMatthias Springer extractOp, extractOp.getType(), newExtractOp); 70699ef9eebSMatthias Springer 70799ef9eebSMatthias Springer return success(); 70899ef9eebSMatthias Springer } 70999ef9eebSMatthias Springer }; 71099ef9eebSMatthias Springer 71199ef9eebSMatthias Springer // Shuffles vector.bitcast op before vector.insert_strided_slice op. 71299ef9eebSMatthias Springer // 71399ef9eebSMatthias Springer // This transforms IR like: 7144623c114SDiego Caballero // %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4> 7154623c114SDiego Caballero // %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8> 7164623c114SDiego Caballero // Into: 7174623c114SDiego Caballero // %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8> 7184623c114SDiego Caballero // %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8> 7194623c114SDiego Caballero // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8> 7204623c114SDiego Caballero // 7214623c114SDiego Caballero struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { 7224623c114SDiego Caballero using OpRewritePattern::OpRewritePattern; 7234623c114SDiego Caballero 7244623c114SDiego Caballero LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 7254623c114SDiego Caballero PatternRewriter &rewriter) const override { 7264623c114SDiego Caballero VectorType castSrcType = bitcastOp.getSourceVectorType(); 7274623c114SDiego Caballero VectorType castDstType = bitcastOp.getResultVectorType(); 7284623c114SDiego Caballero 7294623c114SDiego Caballero // 0-D and scalable vectors are not supported yet. 7304623c114SDiego Caballero if (castSrcType.getRank() == 0 || castSrcType.isScalable() || 7314623c114SDiego Caballero castDstType.isScalable()) 7324623c114SDiego Caballero return failure(); 7334623c114SDiego Caballero 7344623c114SDiego Caballero int64_t castSrcLastDim = castSrcType.getShape().back(); 7354623c114SDiego Caballero int64_t castDstLastDim = castDstType.getShape().back(); 7364623c114SDiego Caballero bool isNumElemsShrink = castSrcLastDim >= castDstLastDim; 7374623c114SDiego Caballero int64_t ratio; 7384623c114SDiego Caballero if (isNumElemsShrink) { 7394623c114SDiego Caballero assert(castSrcLastDim % castDstLastDim == 0); 7404623c114SDiego Caballero ratio = castSrcLastDim / castDstLastDim; 7414623c114SDiego Caballero } else { 7424623c114SDiego Caballero assert(castDstLastDim % castSrcLastDim == 0); 7434623c114SDiego Caballero ratio = castDstLastDim / castSrcLastDim; 7444623c114SDiego Caballero } 7454623c114SDiego Caballero 7464623c114SDiego Caballero auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>(); 7474623c114SDiego Caballero if (!insertOp) 7484623c114SDiego Caballero return failure(); 7494623c114SDiego Caballero 7504623c114SDiego Caballero // Only vector sources are supported for now. 7514623c114SDiego Caballero auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType()); 7524623c114SDiego Caballero if (!insertSrcType) 7534623c114SDiego Caballero return failure(); 7544623c114SDiego Caballero 7554623c114SDiego Caballero // Bitcast the source. 7564623c114SDiego Caballero SmallVector<int64_t> srcDims(insertSrcType.getShape()); 7574623c114SDiego Caballero srcDims.back() = 7584623c114SDiego Caballero isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio; 7594623c114SDiego Caballero VectorType newCastSrcType = 7604623c114SDiego Caballero VectorType::get(srcDims, castDstType.getElementType()); 7614623c114SDiego Caballero auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 7624623c114SDiego Caballero bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); 7634623c114SDiego Caballero 7644623c114SDiego Caballero SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape()); 7654623c114SDiego Caballero dstDims.back() = 7664623c114SDiego Caballero isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio; 7674623c114SDiego Caballero VectorType newCastDstType = 7684623c114SDiego Caballero VectorType::get(dstDims, castDstType.getElementType()); 7694623c114SDiego Caballero 7704623c114SDiego Caballero // Bitcast the destination. 7714623c114SDiego Caballero auto newCastDstOp = rewriter.create<vector::BitCastOp>( 7724623c114SDiego Caballero bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); 7734623c114SDiego Caballero 7744623c114SDiego Caballero // Generate new insert. 7754623c114SDiego Caballero rewriter.replaceOpWithNewOp<vector::InsertOp>( 7764623c114SDiego Caballero bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition()); 7774623c114SDiego Caballero return success(); 7784623c114SDiego Caballero } 7794623c114SDiego Caballero }; 7804623c114SDiego Caballero 7814623c114SDiego Caballero // Shuffles vector.bitcast op before vector.insert_strided_slice op. 7824623c114SDiego Caballero // 7834623c114SDiego Caballero // This transforms IR like: 78499ef9eebSMatthias Springer // %0 = vector.insert_strided_slice %src, %dst { 78599ef9eebSMatthias Springer // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 78699ef9eebSMatthias Springer // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 78799ef9eebSMatthias Springer // Into: 78899ef9eebSMatthias Springer // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 78999ef9eebSMatthias Springer // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 79099ef9eebSMatthias Springer // %2 = vector.insert_strided_slice %src, %dst { 79199ef9eebSMatthias Springer // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 79299ef9eebSMatthias Springer struct BubbleUpBitCastForStridedSliceInsert 79399ef9eebSMatthias Springer : public OpRewritePattern<vector::BitCastOp> { 79499ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern; 79527cc31b6SNicolas Vasilache 79699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 79799ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 79899ef9eebSMatthias Springer VectorType castSrcType = bitcastOp.getSourceVectorType(); 79999ef9eebSMatthias Springer VectorType castDstType = bitcastOp.getResultVectorType(); 80099ef9eebSMatthias Springer assert(castSrcType.getRank() == castDstType.getRank()); 80128f9bfe4SXiang Li // Skip 0-D vector which will not from InsertStridedSliceOp. 80228f9bfe4SXiang Li if (castSrcType.getRank() == 0) 80328f9bfe4SXiang Li return failure(); 80499ef9eebSMatthias Springer 80599ef9eebSMatthias Springer int64_t castSrcLastDim = castSrcType.getShape().back(); 80699ef9eebSMatthias Springer int64_t castDstLastDim = castDstType.getShape().back(); 80799ef9eebSMatthias Springer // Require casting to less elements for now; other cases to be implemented. 80899ef9eebSMatthias Springer if (castSrcLastDim < castDstLastDim) 80999ef9eebSMatthias Springer return failure(); 81099ef9eebSMatthias Springer 81199ef9eebSMatthias Springer assert(castSrcLastDim % castDstLastDim == 0); 81299ef9eebSMatthias Springer int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 81399ef9eebSMatthias Springer 81499ef9eebSMatthias Springer auto insertOp = 8157c38fd60SJacques Pienaar bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>(); 81699ef9eebSMatthias Springer if (!insertOp) 81799ef9eebSMatthias Springer return failure(); 81899ef9eebSMatthias Springer 81999ef9eebSMatthias Springer // Only accept all one strides for now. 8207c38fd60SJacques Pienaar if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(), 8219e5d2495SKazu Hirata [](const APInt &val) { return !val.isOne(); })) 82299ef9eebSMatthias Springer return failure(); 82399ef9eebSMatthias Springer 82499ef9eebSMatthias Springer unsigned rank = insertOp.getSourceVectorType().getRank(); 82599ef9eebSMatthias Springer // Require insert op to have the same rank for the source and destination 82699ef9eebSMatthias Springer // vector; other cases to be implemented. 82799ef9eebSMatthias Springer if (rank != insertOp.getDestVectorType().getRank()) 82899ef9eebSMatthias Springer return failure(); 82999ef9eebSMatthias Springer 830dc26c030Sstanley-nod // Requires that shape of insert op src is castable to dstType. 831dc26c030Sstanley-nod unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth(); 832dc26c030Sstanley-nod unsigned destinationWidth = 833dc26c030Sstanley-nod castDstType.getElementType().getIntOrFloatBitWidth(); 834dc26c030Sstanley-nod unsigned numElements = destinationWidth / sourceWidth; 835dc26c030Sstanley-nod if (insertOp.getSourceVectorType().getNumElements() % numElements != 0) 836dc26c030Sstanley-nod return failure(); 837dc26c030Sstanley-nod 8387c38fd60SJacques Pienaar ArrayAttr newOffsets = insertOp.getOffsets(); 83999ef9eebSMatthias Springer assert(newOffsets.size() == rank); 8407a69a9d7SNicolas Vasilache SmallVector<int64_t> offsets = getIntValueVector(newOffsets); 84199ef9eebSMatthias Springer if (offsets.back() % shrinkRatio != 0) 84299ef9eebSMatthias Springer return failure(); 84399ef9eebSMatthias Springer offsets.back() = offsets.back() / shrinkRatio; 84499ef9eebSMatthias Springer newOffsets = rewriter.getI64ArrayAttr(offsets); 84599ef9eebSMatthias Springer 8467a69a9d7SNicolas Vasilache SmallVector<int64_t> srcDims = 84799ef9eebSMatthias Springer llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 84899ef9eebSMatthias Springer srcDims.back() = srcDims.back() / shrinkRatio; 84999ef9eebSMatthias Springer VectorType newCastSrcType = 85099ef9eebSMatthias Springer VectorType::get(srcDims, castDstType.getElementType()); 85199ef9eebSMatthias Springer 85299ef9eebSMatthias Springer auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 8537c38fd60SJacques Pienaar bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); 85499ef9eebSMatthias Springer 8557a69a9d7SNicolas Vasilache SmallVector<int64_t> dstDims = 85699ef9eebSMatthias Springer llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 85799ef9eebSMatthias Springer dstDims.back() = dstDims.back() / shrinkRatio; 85899ef9eebSMatthias Springer VectorType newCastDstType = 85999ef9eebSMatthias Springer VectorType::get(dstDims, castDstType.getElementType()); 86099ef9eebSMatthias Springer 86199ef9eebSMatthias Springer auto newCastDstOp = rewriter.create<vector::BitCastOp>( 8627c38fd60SJacques Pienaar bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); 86399ef9eebSMatthias Springer 86499ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 86599ef9eebSMatthias Springer bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 8667c38fd60SJacques Pienaar insertOp.getStrides()); 86799ef9eebSMatthias Springer 86899ef9eebSMatthias Springer return success(); 86999ef9eebSMatthias Springer } 87099ef9eebSMatthias Springer }; 87199ef9eebSMatthias Springer 872650f04feSQuinn Dawkins // Breaks down vector.bitcast op 873650f04feSQuinn Dawkins // 874650f04feSQuinn Dawkins // This transforms IR like: 875650f04feSQuinn Dawkins // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 876650f04feSQuinn Dawkins // Into: 877650f04feSQuinn Dawkins // %cst = vector.splat %c0_f32 : vector<4xf32> 878650f04feSQuinn Dawkins // %1 = vector.extract_strided_slice %0 { 879650f04feSQuinn Dawkins // offsets = [0], sizes = [4], strides = [1] 880650f04feSQuinn Dawkins // } : vector<8xf16> to vector<4xf16> 881650f04feSQuinn Dawkins // %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32> 882650f04feSQuinn Dawkins // %4 = vector.insert_strided_slice %2, %cst { 883650f04feSQuinn Dawkins // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 884650f04feSQuinn Dawkins // %5 = vector.extract_strided_slice %0 { 885650f04feSQuinn Dawkins // offsets = [4], sizes = [4], strides = [1] 886650f04feSQuinn Dawkins // } : vector<8xf16> to vector<4xf16> 887650f04feSQuinn Dawkins // %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32> 888650f04feSQuinn Dawkins // %7 = vector.insert_strided_slice %6, %cst { 889650f04feSQuinn Dawkins // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> 890650f04feSQuinn Dawkins struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> { 891650f04feSQuinn Dawkins using OpRewritePattern::OpRewritePattern; 892650f04feSQuinn Dawkins 893650f04feSQuinn Dawkins public: 894650f04feSQuinn Dawkins BreakDownVectorBitCast(MLIRContext *context, 895650f04feSQuinn Dawkins std::function<bool(vector::BitCastOp)> controlFn, 896650f04feSQuinn Dawkins PatternBenefit benefit) 897650f04feSQuinn Dawkins : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} 898650f04feSQuinn Dawkins 899650f04feSQuinn Dawkins LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 900650f04feSQuinn Dawkins PatternRewriter &rewriter) const override { 901650f04feSQuinn Dawkins 902650f04feSQuinn Dawkins if (controlFn && !controlFn(bitcastOp)) 903650f04feSQuinn Dawkins return failure(); 904650f04feSQuinn Dawkins 905650f04feSQuinn Dawkins VectorType castSrcType = bitcastOp.getSourceVectorType(); 906650f04feSQuinn Dawkins VectorType castDstType = bitcastOp.getResultVectorType(); 907650f04feSQuinn Dawkins assert(castSrcType.getRank() == castDstType.getRank()); 908650f04feSQuinn Dawkins 909d88293d8SAndrzej Warzyński // This transformation builds on top of 910d88293d8SAndrzej Warzyński // vector.{extract|insert}_strided_slice, which do not support 911d88293d8SAndrzej Warzyński // extracting/inserting "scallable sub-vectors". Bail out. 912d88293d8SAndrzej Warzyński if (castSrcType.isScalable()) 913d88293d8SAndrzej Warzyński return rewriter.notifyMatchFailure(bitcastOp, 914d88293d8SAndrzej Warzyński "Scalable vectors are not supported"); 915d88293d8SAndrzej Warzyński 916650f04feSQuinn Dawkins // Only support rank 1 case for now. 917650f04feSQuinn Dawkins if (castSrcType.getRank() != 1) 918650f04feSQuinn Dawkins return failure(); 919650f04feSQuinn Dawkins 920650f04feSQuinn Dawkins int64_t castSrcLastDim = castSrcType.getShape().back(); 921650f04feSQuinn Dawkins int64_t castDstLastDim = castDstType.getShape().back(); 922650f04feSQuinn Dawkins // Require casting to less elements for now; other cases to be implemented. 923650f04feSQuinn Dawkins if (castSrcLastDim < castDstLastDim) 924650f04feSQuinn Dawkins return failure(); 925650f04feSQuinn Dawkins 926650f04feSQuinn Dawkins assert(castSrcLastDim % castDstLastDim == 0); 927650f04feSQuinn Dawkins int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 928650f04feSQuinn Dawkins // Nothing to do if it is already bitcasting to a single element. 929650f04feSQuinn Dawkins if (castSrcLastDim == shrinkRatio) 930650f04feSQuinn Dawkins return failure(); 931650f04feSQuinn Dawkins 932650f04feSQuinn Dawkins Location loc = bitcastOp.getLoc(); 933650f04feSQuinn Dawkins Type elemType = castDstType.getElementType(); 934650f04feSQuinn Dawkins assert(elemType.isSignlessIntOrIndexOrFloat()); 935650f04feSQuinn Dawkins 936650f04feSQuinn Dawkins Value zero = rewriter.create<arith::ConstantOp>( 937650f04feSQuinn Dawkins loc, elemType, rewriter.getZeroAttr(elemType)); 938650f04feSQuinn Dawkins Value res = rewriter.create<SplatOp>(loc, castDstType, zero); 939650f04feSQuinn Dawkins 9409cbc1f29SHan-Chung Wang SmallVector<int64_t> sliceShape = {castDstLastDim}; 9419cbc1f29SHan-Chung Wang SmallVector<int64_t> strides = {1}; 942650f04feSQuinn Dawkins VectorType newCastDstType = 943650f04feSQuinn Dawkins VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio}, 944650f04feSQuinn Dawkins castDstType.getElementType()); 945650f04feSQuinn Dawkins 946650f04feSQuinn Dawkins for (int i = 0, e = shrinkRatio; i < e; ++i) { 947650f04feSQuinn Dawkins Value extracted = rewriter.create<ExtractStridedSliceOp>( 948650f04feSQuinn Dawkins loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim}, 949650f04feSQuinn Dawkins sliceShape, strides); 950650f04feSQuinn Dawkins Value bitcast = 951650f04feSQuinn Dawkins rewriter.create<BitCastOp>(loc, newCastDstType, extracted); 952650f04feSQuinn Dawkins res = rewriter.create<InsertStridedSliceOp>( 953650f04feSQuinn Dawkins loc, bitcast, res, 954650f04feSQuinn Dawkins ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides); 955650f04feSQuinn Dawkins } 956650f04feSQuinn Dawkins rewriter.replaceOp(bitcastOp, res); 957650f04feSQuinn Dawkins return success(); 958650f04feSQuinn Dawkins } 959650f04feSQuinn Dawkins 960650f04feSQuinn Dawkins private: 961650f04feSQuinn Dawkins std::function<bool(BitCastOp)> controlFn; 962650f04feSQuinn Dawkins }; 963650f04feSQuinn Dawkins 96459fbba94SAndrzej Warzyński /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: 9654d339ec9SAndrzej Warzynski /// ``` 9664d339ec9SAndrzej Warzynski /// %a = vector.broadcast %arg1 : index to vector<1x4xindex> 9674d339ec9SAndrzej Warzynski /// %b = vector.broadcast %arg2 : index to vector<1x4xindex> 9684d339ec9SAndrzej Warzynski /// %r = arith.addi %a, %b : vector<1x4xindex> 9694d339ec9SAndrzej Warzynski /// ``` 9704d339ec9SAndrzej Warzynski /// Gets converted to: 9714d339ec9SAndrzej Warzynski /// ``` 9724d339ec9SAndrzej Warzynski /// %r = arith.addi %arg0, %arg1 : index 9734d339ec9SAndrzej Warzynski /// %b = vector.broadcast %r : index to vector<1x4xindex> 9744d339ec9SAndrzej Warzynski /// ``` 97559fbba94SAndrzej Warzyński /// 97659fbba94SAndrzej Warzyński /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting 97759fbba94SAndrzej Warzyński /// ops. 9784d339ec9SAndrzej Warzynski struct ReorderElementwiseOpsOnBroadcast final 9794d339ec9SAndrzej Warzynski : public OpTraitRewritePattern<OpTrait::Elementwise> { 9804d339ec9SAndrzej Warzynski using OpTraitRewritePattern::OpTraitRewritePattern; 9814d339ec9SAndrzej Warzynski LogicalResult matchAndRewrite(Operation *op, 9824d339ec9SAndrzej Warzynski PatternRewriter &rewriter) const override { 9834d339ec9SAndrzej Warzynski if (op->getNumResults() != 1) 9844d339ec9SAndrzej Warzynski return failure(); 9854d339ec9SAndrzej Warzynski if (!llvm::isa<ShapedType>(op->getResults()[0].getType())) 9864d339ec9SAndrzej Warzynski return failure(); 9874d339ec9SAndrzej Warzynski if (!OpTrait::hasElementwiseMappableTraits(op)) 988efe3db21SAndrzej Warzyński return rewriter.notifyMatchFailure( 989efe3db21SAndrzej Warzyński op, "Op doesn't have ElementwiseMappableTraits"); 990efe3db21SAndrzej Warzyński if (op->getNumOperands() == 0) 9914d339ec9SAndrzej Warzynski return failure(); 992efe3db21SAndrzej Warzyński if (op->getResults()[0].getType() != op->getOperand(0).getType()) 993efe3db21SAndrzej Warzyński return rewriter.notifyMatchFailure(op, 994efe3db21SAndrzej Warzyński "result and operand type mismatch"); 995f28f09dcSMaheshRavishankar if (isa<vector::FMAOp>(op)) { 996efe3db21SAndrzej Warzyński return rewriter.notifyMatchFailure( 997efe3db21SAndrzej Warzyński op, 998efe3db21SAndrzej Warzyński "Op only accepts vector types - not supported as broadcast source " 999efe3db21SAndrzej Warzyński "might be a scalar"); 1000f28f09dcSMaheshRavishankar } 10014d339ec9SAndrzej Warzynski 100259fbba94SAndrzej Warzyński // Get the type of the lhs operand 100359fbba94SAndrzej Warzyński auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); 100459fbba94SAndrzej Warzyński if (!lhsBcastOrSplat || 100559fbba94SAndrzej Warzyński !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) 10064d339ec9SAndrzej Warzynski return failure(); 100759fbba94SAndrzej Warzyński auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); 10084d339ec9SAndrzej Warzynski 100959fbba94SAndrzej Warzyński // Make sure that all operands are broadcast from identical types: 101059fbba94SAndrzej Warzyński // * scalar (`vector.broadcast` + `vector.splat`), or 101159fbba94SAndrzej Warzyński // * vector (`vector.broadcast`). 101259fbba94SAndrzej Warzyński // Otherwise the re-ordering wouldn't be safe. 101359fbba94SAndrzej Warzyński if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { 10144d339ec9SAndrzej Warzynski auto bcast = val.getDefiningOp<vector::BroadcastOp>(); 101559fbba94SAndrzej Warzyński if (bcast) 101659fbba94SAndrzej Warzyński return (bcast.getOperand().getType() == lhsBcastOrSplatType); 101759fbba94SAndrzej Warzyński auto splat = val.getDefiningOp<vector::SplatOp>(); 101859fbba94SAndrzej Warzyński if (splat) 101959fbba94SAndrzej Warzyński return (splat.getOperand().getType() == lhsBcastOrSplatType); 102059fbba94SAndrzej Warzyński return false; 10214d339ec9SAndrzej Warzynski })) { 10224d339ec9SAndrzej Warzynski return failure(); 10234d339ec9SAndrzej Warzynski } 10244d339ec9SAndrzej Warzynski 102559fbba94SAndrzej Warzyński // Collect the source values before broadcasting 10264d339ec9SAndrzej Warzynski SmallVector<Value> srcValues; 10274d339ec9SAndrzej Warzynski srcValues.reserve(op->getNumOperands()); 10284d339ec9SAndrzej Warzynski for (Value operand : op->getOperands()) { 102959fbba94SAndrzej Warzyński srcValues.push_back(operand.getDefiningOp()->getOperand(0)); 10304d339ec9SAndrzej Warzynski } 10314d339ec9SAndrzej Warzynski 103259fbba94SAndrzej Warzyński // Create the "elementwise" Op 10334d339ec9SAndrzej Warzynski Operation *elementwiseOp = 10344d339ec9SAndrzej Warzynski rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, 103559fbba94SAndrzej Warzyński lhsBcastOrSplatType, op->getAttrs()); 10364d339ec9SAndrzej Warzynski 103759fbba94SAndrzej Warzyński // Replace the original Op with the elementwise Op 10384d339ec9SAndrzej Warzynski auto vectorType = op->getResultTypes()[0]; 10394d339ec9SAndrzej Warzynski rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 10404d339ec9SAndrzej Warzynski op, vectorType, elementwiseOp->getResults()); 10414d339ec9SAndrzej Warzynski 10424d339ec9SAndrzej Warzynski return success(); 10434d339ec9SAndrzej Warzynski } 10444d339ec9SAndrzej Warzynski }; 10454d339ec9SAndrzej Warzynski 104699ef9eebSMatthias Springer // Helper that returns a vector comparison that constructs a mask: 104799ef9eebSMatthias Springer // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 104899ef9eebSMatthias Springer // 104999ef9eebSMatthias Springer // If `dim == 0` then the result will be a 0-D vector. 105099ef9eebSMatthias Springer // 105199ef9eebSMatthias Springer // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 105299ef9eebSMatthias Springer // much more compact, IR for this operation, but LLVM eventually 105399ef9eebSMatthias Springer // generates more elaborate instructions for this intrinsic since it 105499ef9eebSMatthias Springer // is very conservative on the boundary conditions. 105599ef9eebSMatthias Springer static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 10567bc8ad51SJavier Setoain bool force32BitVectorIndices, int64_t dim, 105799ef9eebSMatthias Springer Value b, Value *off = nullptr) { 105899ef9eebSMatthias Springer auto loc = op->getLoc(); 105999ef9eebSMatthias Springer // If we can assume all indices fit in 32-bit, we perform the vector 106099ef9eebSMatthias Springer // comparison in 32-bit to get a higher degree of SIMD parallelism. 106199ef9eebSMatthias Springer // Otherwise we perform the vector comparison using 64-bit indices. 106299ef9eebSMatthias Springer Type idxType = 10637bc8ad51SJavier Setoain force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 106499ef9eebSMatthias Springer DenseIntElementsAttr indicesAttr; 10657bc8ad51SJavier Setoain if (dim == 0 && force32BitVectorIndices) { 106699ef9eebSMatthias Springer indicesAttr = DenseIntElementsAttr::get( 106799ef9eebSMatthias Springer VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 106899ef9eebSMatthias Springer } else if (dim == 0) { 106999ef9eebSMatthias Springer indicesAttr = DenseIntElementsAttr::get( 107099ef9eebSMatthias Springer VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 10717bc8ad51SJavier Setoain } else if (force32BitVectorIndices) { 107299ef9eebSMatthias Springer indicesAttr = rewriter.getI32VectorAttr( 107399ef9eebSMatthias Springer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 107499ef9eebSMatthias Springer } else { 107599ef9eebSMatthias Springer indicesAttr = rewriter.getI64VectorAttr( 107699ef9eebSMatthias Springer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 107799ef9eebSMatthias Springer } 107899ef9eebSMatthias Springer Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 107999ef9eebSMatthias Springer // Add in an offset if requested. 108099ef9eebSMatthias Springer if (off) { 1081a75a46dbSJavier Setoain Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); 10826a8ba318SRiver Riddle Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 108399ef9eebSMatthias Springer indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 108499ef9eebSMatthias Springer } 108599ef9eebSMatthias Springer // Construct the vector comparison. 1086a75a46dbSJavier Setoain Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); 10876a8ba318SRiver Riddle Value bounds = 10886a8ba318SRiver Riddle rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 108999ef9eebSMatthias Springer return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 109099ef9eebSMatthias Springer bounds); 109199ef9eebSMatthias Springer } 109299ef9eebSMatthias Springer 109399ef9eebSMatthias Springer template <typename ConcreteOp> 109499ef9eebSMatthias Springer struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 109599ef9eebSMatthias Springer public: 109627cc31b6SNicolas Vasilache explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt, 109727cc31b6SNicolas Vasilache PatternBenefit benefit = 1) 109827cc31b6SNicolas Vasilache : mlir::OpRewritePattern<ConcreteOp>(context, benefit), 10997bc8ad51SJavier Setoain force32BitVectorIndices(enableIndexOpt) {} 110099ef9eebSMatthias Springer 110199ef9eebSMatthias Springer LogicalResult matchAndRewrite(ConcreteOp xferOp, 110299ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 110399ef9eebSMatthias Springer if (!xferOp.hasOutOfBoundsDim()) 110499ef9eebSMatthias Springer return failure(); 110599ef9eebSMatthias Springer 1106be650de5SKazu Hirata if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty()) 110799ef9eebSMatthias Springer return failure(); 110899ef9eebSMatthias Springer 110999ef9eebSMatthias Springer Location loc = xferOp->getLoc(); 111099ef9eebSMatthias Springer VectorType vtp = xferOp.getVectorType(); 111199ef9eebSMatthias Springer 1112f2b89c7aSJavier Setoain // Create the in-bounds mask with all elements between [0 .. dim - offset) 1113f2b89c7aSJavier Setoain // set and [dim - offset .. vector_length) unset. 111499ef9eebSMatthias Springer // 111599ef9eebSMatthias Springer // TODO: when the leaf transfer rank is k > 1, we need the last `k` 111699ef9eebSMatthias Springer // dimensions here. 11177c38fd60SJacques Pienaar unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; 11187c38fd60SJacques Pienaar Value off = xferOp.getIndices()[lastIndex]; 111999ef9eebSMatthias Springer Value dim = 11207c38fd60SJacques Pienaar vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); 1121f2b89c7aSJavier Setoain Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 1122f2b89c7aSJavier Setoain Value mask = rewriter.create<vector::CreateMaskOp>( 1123f2b89c7aSJavier Setoain loc, 1124f2b89c7aSJavier Setoain VectorType::get(vtp.getShape(), rewriter.getI1Type(), 1125f22af204SAndrzej Warzynski vtp.getScalableDims()), 1126f2b89c7aSJavier Setoain b); 11277c38fd60SJacques Pienaar if (xferOp.getMask()) { 112899ef9eebSMatthias Springer // Intersect the in-bounds with the mask specified as an op parameter. 11297c38fd60SJacques Pienaar mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask()); 113099ef9eebSMatthias Springer } 113199ef9eebSMatthias Springer 11325fcf907bSMatthias Springer rewriter.modifyOpInPlace(xferOp, [&]() { 11337c38fd60SJacques Pienaar xferOp.getMaskMutable().assign(mask); 11347c38fd60SJacques Pienaar xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); 113599ef9eebSMatthias Springer }); 113699ef9eebSMatthias Springer 113799ef9eebSMatthias Springer return success(); 113899ef9eebSMatthias Springer } 113999ef9eebSMatthias Springer 114099ef9eebSMatthias Springer private: 11417bc8ad51SJavier Setoain const bool force32BitVectorIndices; 114299ef9eebSMatthias Springer }; 114399ef9eebSMatthias Springer 114499ef9eebSMatthias Springer /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 114599ef9eebSMatthias Springer class VectorCreateMaskOpConversion 114699ef9eebSMatthias Springer : public OpRewritePattern<vector::CreateMaskOp> { 114799ef9eebSMatthias Springer public: 114899ef9eebSMatthias Springer explicit VectorCreateMaskOpConversion(MLIRContext *context, 114927cc31b6SNicolas Vasilache bool enableIndexOpt, 115027cc31b6SNicolas Vasilache PatternBenefit benefit = 1) 115127cc31b6SNicolas Vasilache : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit), 11527bc8ad51SJavier Setoain force32BitVectorIndices(enableIndexOpt) {} 115399ef9eebSMatthias Springer 115499ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::CreateMaskOp op, 115599ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 115699ef9eebSMatthias Springer auto dstType = op.getType(); 11575550c821STres Popp if (cast<VectorType>(dstType).isScalable()) 1158a75a46dbSJavier Setoain return failure(); 115999ef9eebSMatthias Springer int64_t rank = dstType.getRank(); 116099ef9eebSMatthias Springer if (rank > 1) 116199ef9eebSMatthias Springer return failure(); 116299ef9eebSMatthias Springer rewriter.replaceOp( 11637bc8ad51SJavier Setoain op, buildVectorComparison(rewriter, op, force32BitVectorIndices, 116499ef9eebSMatthias Springer rank == 0 ? 0 : dstType.getDimSize(0), 116599ef9eebSMatthias Springer op.getOperand(0))); 116699ef9eebSMatthias Springer return success(); 116799ef9eebSMatthias Springer } 116899ef9eebSMatthias Springer 116999ef9eebSMatthias Springer private: 11707bc8ad51SJavier Setoain const bool force32BitVectorIndices; 117199ef9eebSMatthias Springer }; 117299ef9eebSMatthias Springer 117315a08cf2SDiego Caballero /// Returns true if all the `i1` elements of `constantOp` are set to `value`. 117415a08cf2SDiego Caballero static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { 117515a08cf2SDiego Caballero auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue()); 117615a08cf2SDiego Caballero // TODO: Support non-dense constant. 117715a08cf2SDiego Caballero if (!denseAttr) 117815a08cf2SDiego Caballero return false; 117915a08cf2SDiego Caballero 118015a08cf2SDiego Caballero assert(denseAttr.getElementType().isInteger(1) && "Unexpected type"); 118115a08cf2SDiego Caballero return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value; 118215a08cf2SDiego Caballero } 118315a08cf2SDiego Caballero 118415a08cf2SDiego Caballero /// Folds a select operation between an all-true and all-false vector. For now, 118515a08cf2SDiego Caballero /// only single element vectors (i.e., vector<1xi1>) are supported. That is: 118615a08cf2SDiego Caballero /// 118715a08cf2SDiego Caballero /// %true = arith.constant dense<true> : vector<1xi1> 118815a08cf2SDiego Caballero /// %false = arith.constant dense<false> : vector<1xi1> 118915a08cf2SDiego Caballero /// %result = arith.select %cond, %true, %false : i1, vector<1xi1> 119015a08cf2SDiego Caballero /// => 119115a08cf2SDiego Caballero /// %result = vector.broadcast %cond : i1 to vector<1xi1> 119215a08cf2SDiego Caballero /// 119315a08cf2SDiego Caballero /// InstCombine seems to handle vectors with multiple elements but not the 119415a08cf2SDiego Caballero /// single element ones. 119515a08cf2SDiego Caballero struct FoldI1Select : public OpRewritePattern<arith::SelectOp> { 119615a08cf2SDiego Caballero using OpRewritePattern<arith::SelectOp>::OpRewritePattern; 119715a08cf2SDiego Caballero 119815a08cf2SDiego Caballero LogicalResult matchAndRewrite(arith::SelectOp selectOp, 119915a08cf2SDiego Caballero PatternRewriter &rewriter) const override { 120015a08cf2SDiego Caballero auto vecType = dyn_cast<VectorType>(selectOp.getType()); 120115a08cf2SDiego Caballero if (!vecType || !vecType.getElementType().isInteger(1)) 120215a08cf2SDiego Caballero return failure(); 120315a08cf2SDiego Caballero 120415a08cf2SDiego Caballero // Only scalar conditions can be folded. 120515a08cf2SDiego Caballero Value cond = selectOp.getCondition(); 120615a08cf2SDiego Caballero if (isa<VectorType>(cond.getType())) 120715a08cf2SDiego Caballero return failure(); 120815a08cf2SDiego Caballero 120915a08cf2SDiego Caballero // TODO: Support n-D and scalable vectors. 121015a08cf2SDiego Caballero if (vecType.getRank() != 1 || vecType.isScalable()) 121115a08cf2SDiego Caballero return failure(); 121215a08cf2SDiego Caballero 121315a08cf2SDiego Caballero // TODO: Support vectors with multiple elements. 121415a08cf2SDiego Caballero if (vecType.getShape()[0] != 1) 121515a08cf2SDiego Caballero return failure(); 121615a08cf2SDiego Caballero 121715a08cf2SDiego Caballero auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>(); 121815a08cf2SDiego Caballero if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true)) 121915a08cf2SDiego Caballero return failure(); 122015a08cf2SDiego Caballero 122115a08cf2SDiego Caballero auto falseConst = 122215a08cf2SDiego Caballero selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>(); 122315a08cf2SDiego Caballero if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false)) 122415a08cf2SDiego Caballero return failure(); 122515a08cf2SDiego Caballero 122615a08cf2SDiego Caballero // Replace select with its condition broadcasted to single element vector. 122715a08cf2SDiego Caballero auto elemType = rewriter.getIntegerType(vecType.getNumElements()); 122815a08cf2SDiego Caballero auto bcastType = VectorType::get(/*shape=*/{1}, elemType); 122915a08cf2SDiego Caballero rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond); 123015a08cf2SDiego Caballero return success(); 123115a08cf2SDiego Caballero } 123215a08cf2SDiego Caballero }; 123315a08cf2SDiego Caballero 123412b676deSHan-Chung Wang /// Returns the number of dims can be folded away from transfer ops. It returns 123512b676deSHan-Chung Wang /// a failure if it can not determine the number of dims to be folded. 1236c65fb32dSAndrzej Warzyński /// 1237c65fb32dSAndrzej Warzyński /// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and 1238c65fb32dSAndrzej Warzyński /// `vectorType` is vector<16x16x1x1xf32> 1239c65fb32dSAndrzej Warzyński /// (there two inner most dims can be dropped by memref.subview ops) 1240c65fb32dSAndrzej Warzyński /// 1241c65fb32dSAndrzej Warzyński /// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with 1242c65fb32dSAndrzej Warzyński /// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32> 1243c65fb32dSAndrzej Warzyński /// (only the inner most unit dim of `srcType` can be dropped) 1244c65fb32dSAndrzej Warzyński /// 1245c65fb32dSAndrzej Warzyński /// Ex 3: return "0" if `srcType` is memref<512x16x1x1xf32> and 1246c65fb32dSAndrzej Warzyński /// `vectorType` is vector<16x16x1x[1]xf32> 1247c65fb32dSAndrzej Warzyński /// (the most inner dim in `vectorType` is not a unit dim (it's a "scalable 1248c65fb32dSAndrzej Warzyński /// unit") 124912b676deSHan-Chung Wang static FailureOr<size_t> 125012b676deSHan-Chung Wang getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { 125112b676deSHan-Chung Wang SmallVector<int64_t> srcStrides; 125212b676deSHan-Chung Wang int64_t srcOffset; 12536aaa8f25SMatthias Springer if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset))) 125412b676deSHan-Chung Wang return failure(); 125512b676deSHan-Chung Wang 1256e01ff823SBenjamin Maxwell auto isUnitDim = [](VectorType type, int dim) { 1257e01ff823SBenjamin Maxwell return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim]; 1258e01ff823SBenjamin Maxwell }; 1259e01ff823SBenjamin Maxwell 126012b676deSHan-Chung Wang // According to vector.transfer_read/write semantics, the vector can be a 126112b676deSHan-Chung Wang // slice. Thus, we have to offset the check index with `rankDiff` in 126212b676deSHan-Chung Wang // `srcStrides` and source dim sizes. 126312b676deSHan-Chung Wang size_t result = 0; 126412b676deSHan-Chung Wang int rankDiff = srcType.getRank() - vectorType.getRank(); 126512b676deSHan-Chung Wang for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) { 126612b676deSHan-Chung Wang // Check that the inner dim size is 1 for both memref type and vector slice. 126712b676deSHan-Chung Wang // It can be folded only if they are 1 and the stride is 1. 126812b676deSHan-Chung Wang int dim = vectorType.getRank() - i - 1; 126912b676deSHan-Chung Wang if (srcStrides[dim + rankDiff] != 1 || 1270e01ff823SBenjamin Maxwell srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim)) 127112b676deSHan-Chung Wang break; 127212b676deSHan-Chung Wang result++; 127312b676deSHan-Chung Wang } 127412b676deSHan-Chung Wang return result; 127512b676deSHan-Chung Wang } 127612b676deSHan-Chung Wang 127712b676deSHan-Chung Wang /// Drop inner most contiguous unit dimensions from transfer_read operand. 127812b676deSHan-Chung Wang class DropInnerMostUnitDimsTransferRead 127912b676deSHan-Chung Wang : public OpRewritePattern<vector::TransferReadOp> { 128027cc31b6SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 128199ef9eebSMatthias Springer 128299ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 128399ef9eebSMatthias Springer PatternRewriter &rewriter) const override { 128499ef9eebSMatthias Springer // TODO: support 0-d corner case. 128599ef9eebSMatthias Springer if (readOp.getTransferRank() == 0) 128699ef9eebSMatthias Springer return failure(); 128799ef9eebSMatthias Springer 128899ef9eebSMatthias Springer // TODO: support mask. 12897c38fd60SJacques Pienaar if (readOp.getMask()) 129099ef9eebSMatthias Springer return failure(); 129199ef9eebSMatthias Springer 12925550c821STres Popp auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType()); 1293d592c8ecSDiego Caballero if (!srcType) 129499ef9eebSMatthias Springer return failure(); 129599ef9eebSMatthias Springer 12967c38fd60SJacques Pienaar if (!readOp.getPermutationMap().isMinorIdentity()) 129799ef9eebSMatthias Springer return failure(); 129899ef9eebSMatthias Springer 129999ef9eebSMatthias Springer auto targetType = readOp.getVectorType(); 130099ef9eebSMatthias Springer if (targetType.getRank() <= 1) 130199ef9eebSMatthias Springer return failure(); 130299ef9eebSMatthias Springer 130312b676deSHan-Chung Wang FailureOr<size_t> maybeDimsToDrop = 130412b676deSHan-Chung Wang getTransferFoldableInnerUnitDims(srcType, targetType); 130512b676deSHan-Chung Wang if (failed(maybeDimsToDrop)) 130699ef9eebSMatthias Springer return failure(); 130799ef9eebSMatthias Springer 130812b676deSHan-Chung Wang size_t dimsToDrop = maybeDimsToDrop.value(); 130999ef9eebSMatthias Springer if (dimsToDrop == 0) 131099ef9eebSMatthias Springer return failure(); 131199ef9eebSMatthias Springer 13126479a5a4SAndrzej Warzyński auto inBounds = readOp.getInBoundsValues(); 13136479a5a4SAndrzej Warzyński auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop); 13146479a5a4SAndrzej Warzyński if (llvm::is_contained(droppedInBounds, false)) 131577db8b08SAndrzej Warzyński return failure(); 131677db8b08SAndrzej Warzyński 131799ef9eebSMatthias Springer auto resultTargetVecType = 131899ef9eebSMatthias Springer VectorType::get(targetType.getShape().drop_back(dimsToDrop), 1319e01ff823SBenjamin Maxwell targetType.getElementType(), 1320e01ff823SBenjamin Maxwell targetType.getScalableDims().drop_back(dimsToDrop)); 132199ef9eebSMatthias Springer 132299ef9eebSMatthias Springer auto loc = readOp.getLoc(); 1323d592c8ecSDiego Caballero SmallVector<OpFoldResult> sizes = 1324d592c8ecSDiego Caballero memref::getMixedSizes(rewriter, loc, readOp.getSource()); 1325d592c8ecSDiego Caballero SmallVector<OpFoldResult> offsets(srcType.getRank(), 1326d592c8ecSDiego Caballero rewriter.getIndexAttr(0)); 1327d592c8ecSDiego Caballero SmallVector<OpFoldResult> strides(srcType.getRank(), 1328d592c8ecSDiego Caballero rewriter.getIndexAttr(1)); 13297c83d1bdSHan-Chung Wang auto resultMemrefType = 13307c83d1bdSHan-Chung Wang cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( 13317c83d1bdSHan-Chung Wang srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, 13327c83d1bdSHan-Chung Wang strides)); 13331f5807ebSAndrzej Warzyński ArrayAttr inBoundsAttr = rewriter.getArrayAttr( 13341f5807ebSAndrzej Warzyński readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); 133599ef9eebSMatthias Springer Value rankedReducedView = rewriter.create<memref::SubViewOp>( 1336d592c8ecSDiego Caballero loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides); 133799ef9eebSMatthias Springer auto permMap = getTransferMinorIdentityMap( 13385550c821STres Popp cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType); 133999ef9eebSMatthias Springer Value result = rewriter.create<vector::TransferReadOp>( 134099ef9eebSMatthias Springer loc, resultTargetVecType, rankedReducedView, 13417c38fd60SJacques Pienaar readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 13427c38fd60SJacques Pienaar readOp.getPadding(), 134399ef9eebSMatthias Springer // TODO: support mask. 134499ef9eebSMatthias Springer /*mask=*/Value(), inBoundsAttr); 134599ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 134699ef9eebSMatthias Springer result); 134799ef9eebSMatthias Springer return success(); 134899ef9eebSMatthias Springer } 134999ef9eebSMatthias Springer }; 135099ef9eebSMatthias Springer 135112b676deSHan-Chung Wang /// Drop inner most contiguous unit dimensions from transfer_write operand. 135212b676deSHan-Chung Wang /// E.g., 135312b676deSHan-Chung Wang /// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] 135412b676deSHan-Chung Wang /// {in_bounds = [true, true, true, true, true]} 135512b676deSHan-Chung Wang /// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32> 135612b676deSHan-Chung Wang /// 135712b676deSHan-Chung Wang /// will be replaced with 135812b676deSHan-Chung Wang /// 135912b676deSHan-Chung Wang /// %subview = memref.subview %arg0 136012b676deSHan-Chung Wang /// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1] 136112b676deSHan-Chung Wang /// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32> 136212b676deSHan-Chung Wang /// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32> 136312b676deSHan-Chung Wang /// to vector<1x16x16xf32> 136412b676deSHan-Chung Wang /// vector.transfer_write %0, %subview[%c0, %arg2, %c0] 136512b676deSHan-Chung Wang /// {in_bounds = [true, true, true]} 136612b676deSHan-Chung Wang /// : vector<1x16x16xf32>, memref<1x512x16xf32> 1367c65fb32dSAndrzej Warzyński /// 1368c65fb32dSAndrzej Warzyński /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`). 136912b676deSHan-Chung Wang class DropInnerMostUnitDimsTransferWrite 137012b676deSHan-Chung Wang : public OpRewritePattern<vector::TransferWriteOp> { 137112b676deSHan-Chung Wang using OpRewritePattern::OpRewritePattern; 137212b676deSHan-Chung Wang 137312b676deSHan-Chung Wang LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 137412b676deSHan-Chung Wang PatternRewriter &rewriter) const override { 137512b676deSHan-Chung Wang // TODO: support 0-d corner case. 137612b676deSHan-Chung Wang if (writeOp.getTransferRank() == 0) 137712b676deSHan-Chung Wang return failure(); 137812b676deSHan-Chung Wang 137912b676deSHan-Chung Wang // TODO: support mask. 138012b676deSHan-Chung Wang if (writeOp.getMask()) 138112b676deSHan-Chung Wang return failure(); 138212b676deSHan-Chung Wang 138312b676deSHan-Chung Wang auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType()); 1384d193ac4fSHan-Chung Wang if (!srcType) 138512b676deSHan-Chung Wang return failure(); 138612b676deSHan-Chung Wang 138712b676deSHan-Chung Wang if (!writeOp.getPermutationMap().isMinorIdentity()) 138812b676deSHan-Chung Wang return failure(); 138912b676deSHan-Chung Wang 139012b676deSHan-Chung Wang auto targetType = writeOp.getVectorType(); 139112b676deSHan-Chung Wang if (targetType.getRank() <= 1) 139212b676deSHan-Chung Wang return failure(); 139312b676deSHan-Chung Wang 139412b676deSHan-Chung Wang FailureOr<size_t> maybeDimsToDrop = 139512b676deSHan-Chung Wang getTransferFoldableInnerUnitDims(srcType, targetType); 139612b676deSHan-Chung Wang if (failed(maybeDimsToDrop)) 139712b676deSHan-Chung Wang return failure(); 139812b676deSHan-Chung Wang 139912b676deSHan-Chung Wang size_t dimsToDrop = maybeDimsToDrop.value(); 140012b676deSHan-Chung Wang if (dimsToDrop == 0) 140112b676deSHan-Chung Wang return failure(); 140212b676deSHan-Chung Wang 14036479a5a4SAndrzej Warzyński auto inBounds = writeOp.getInBoundsValues(); 14046479a5a4SAndrzej Warzyński auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop); 14056479a5a4SAndrzej Warzyński if (llvm::is_contained(droppedInBounds, false)) 14066479a5a4SAndrzej Warzyński return failure(); 14076479a5a4SAndrzej Warzyński 140812b676deSHan-Chung Wang auto resultTargetVecType = 140912b676deSHan-Chung Wang VectorType::get(targetType.getShape().drop_back(dimsToDrop), 1410e01ff823SBenjamin Maxwell targetType.getElementType(), 1411e01ff823SBenjamin Maxwell targetType.getScalableDims().drop_back(dimsToDrop)); 141212b676deSHan-Chung Wang 1413d193ac4fSHan-Chung Wang Location loc = writeOp.getLoc(); 1414d193ac4fSHan-Chung Wang SmallVector<OpFoldResult> sizes = 1415d193ac4fSHan-Chung Wang memref::getMixedSizes(rewriter, loc, writeOp.getSource()); 1416d193ac4fSHan-Chung Wang SmallVector<OpFoldResult> offsets(srcType.getRank(), 1417d193ac4fSHan-Chung Wang rewriter.getIndexAttr(0)); 1418d193ac4fSHan-Chung Wang SmallVector<OpFoldResult> strides(srcType.getRank(), 1419d193ac4fSHan-Chung Wang rewriter.getIndexAttr(1)); 14207c83d1bdSHan-Chung Wang auto resultMemrefType = 14217c83d1bdSHan-Chung Wang cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( 14227c83d1bdSHan-Chung Wang srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, 14237c83d1bdSHan-Chung Wang strides)); 14241f5807ebSAndrzej Warzyński ArrayAttr inBoundsAttr = rewriter.getArrayAttr( 14251f5807ebSAndrzej Warzyński writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)); 142612b676deSHan-Chung Wang 142712b676deSHan-Chung Wang Value rankedReducedView = rewriter.create<memref::SubViewOp>( 1428d193ac4fSHan-Chung Wang loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides); 142912b676deSHan-Chung Wang auto permMap = getTransferMinorIdentityMap( 143012b676deSHan-Chung Wang cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType); 143112b676deSHan-Chung Wang 143212b676deSHan-Chung Wang auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>( 143312b676deSHan-Chung Wang loc, resultTargetVecType, writeOp.getVector()); 143412b676deSHan-Chung Wang rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 143512b676deSHan-Chung Wang writeOp, shapeCast, rankedReducedView, 143612b676deSHan-Chung Wang writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 143712b676deSHan-Chung Wang // TODO: support mask. 143812b676deSHan-Chung Wang /*mask=*/Value(), inBoundsAttr); 143912b676deSHan-Chung Wang return success(); 144012b676deSHan-Chung Wang } 144112b676deSHan-Chung Wang }; 144212b676deSHan-Chung Wang 1443fb7ef637SJakub Kuderski /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul 1444fb7ef637SJakub Kuderski /// semantics to a contraction suitable for MMT (matrix matrix multiplication 1445fb7ef637SJakub Kuderski /// with the RHS transposed) lowering. 1446fb7ef637SJakub Kuderski struct CanonicalizeContractMatmulToMMT final 1447fb7ef637SJakub Kuderski : OpRewritePattern<vector::ContractionOp> { 1448fb7ef637SJakub Kuderski using OpRewritePattern::OpRewritePattern; 1449fb7ef637SJakub Kuderski 1450fb7ef637SJakub Kuderski using FilterConstraintType = 1451fb7ef637SJakub Kuderski std::function<LogicalResult(vector::ContractionOp op)>; 1452fb7ef637SJakub Kuderski 1453fb7ef637SJakub Kuderski CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit, 1454fb7ef637SJakub Kuderski FilterConstraintType constraint) 1455fb7ef637SJakub Kuderski : OpRewritePattern<vector::ContractionOp>(context, benefit), 1456fb7ef637SJakub Kuderski filter(std::move(constraint)) {} 1457fb7ef637SJakub Kuderski 1458fb7ef637SJakub Kuderski LogicalResult matchAndRewrite(vector::ContractionOp op, 1459fb7ef637SJakub Kuderski PatternRewriter &rewriter) const override { 1460fb7ef637SJakub Kuderski if (failed(filter(op))) 1461fb7ef637SJakub Kuderski return failure(); 1462fb7ef637SJakub Kuderski 1463fb7ef637SJakub Kuderski Location loc = op.getLoc(); 1464fb7ef637SJakub Kuderski Value lhs = op.getLhs(); 1465fb7ef637SJakub Kuderski Value rhs = op.getRhs(); 1466fb7ef637SJakub Kuderski Value res = op.getAcc(); 1467fb7ef637SJakub Kuderski 1468fb7ef637SJakub Kuderski // Set up the parallel/reduction structure in right form. 1469fb7ef637SJakub Kuderski using MapList = ArrayRef<ArrayRef<AffineExpr>>; 1470fe8a62c4SUday Bondhugula auto infer = [&](MapList m) { 1471fe8a62c4SUday Bondhugula return AffineMap::inferFromExprList(m, op.getContext()); 1472fe8a62c4SUday Bondhugula }; 1473fb7ef637SJakub Kuderski AffineExpr m; 1474fb7ef637SJakub Kuderski AffineExpr n; 1475fb7ef637SJakub Kuderski AffineExpr k; 1476fb7ef637SJakub Kuderski bindDims(rewriter.getContext(), m, n, k); 1477fb7ef637SJakub Kuderski static constexpr std::array<int64_t, 2> perm = {1, 0}; 1478fb7ef637SJakub Kuderski auto iteratorTypes = op.getIteratorTypes().getValue(); 1479fb7ef637SJakub Kuderski SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); 1480fb7ef637SJakub Kuderski if (iteratorTypes.size() != 3 || 1481fb7ef637SJakub Kuderski !vector::isParallelIterator(iteratorTypes[0]) || 1482fb7ef637SJakub Kuderski !vector::isParallelIterator(iteratorTypes[1]) || 1483fb7ef637SJakub Kuderski !vector::isReductionIterator(iteratorTypes[2])) 1484fb7ef637SJakub Kuderski return rewriter.notifyMatchFailure(op, "contraction is not a gemm"); 1485fb7ef637SJakub Kuderski 1486fb7ef637SJakub Kuderski // The canonical form is "TNT" = A row-major, B col-major, C row-major. 1487fb7ef637SJakub Kuderski const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}}); 1488fb7ef637SJakub Kuderski if (maps == canonicalForm) 1489fb7ef637SJakub Kuderski return rewriter.notifyMatchFailure(op, "already in the canonical form"); 1490fb7ef637SJakub Kuderski 1491fb7ef637SJakub Kuderski // Create a vector transpose making sure to emit zero/sign-extend at the 1492fb7ef637SJakub Kuderski // end. 1493fb7ef637SJakub Kuderski auto createTranspose = [&rewriter, loc](Value mat) -> Value { 1494fb7ef637SJakub Kuderski if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) { 1495fb7ef637SJakub Kuderski Value trans = 1496fb7ef637SJakub Kuderski rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm); 14975041fe84SLei Zhang VectorType newType = 14984b2ba5a6SAndrzej Warzyński cast<VectorType>(trans.getType()) 14994b2ba5a6SAndrzej Warzyński .clone(cast<VectorType>(mat.getType()).getElementType()); 15005041fe84SLei Zhang return rewriter.create<arith::ExtSIOp>(loc, newType, trans); 1501fb7ef637SJakub Kuderski } 1502fb7ef637SJakub Kuderski if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) { 1503fb7ef637SJakub Kuderski Value trans = 1504fb7ef637SJakub Kuderski rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm); 15055041fe84SLei Zhang VectorType newType = 15065041fe84SLei Zhang VectorType::get(cast<VectorType>(trans.getType()).getShape(), 15075041fe84SLei Zhang cast<VectorType>(mat.getType()).getElementType()); 15085041fe84SLei Zhang return rewriter.create<arith::ExtUIOp>(loc, newType, trans); 1509fb7ef637SJakub Kuderski } 1510fb7ef637SJakub Kuderski return rewriter.create<vector::TransposeOp>(loc, mat, perm); 1511fb7ef637SJakub Kuderski }; 1512fb7ef637SJakub Kuderski 1513fb7ef637SJakub Kuderski if (maps == infer({{m, k}, {k, n}, {m, n}})) { 1514fb7ef637SJakub Kuderski rhs = createTranspose(rhs); 1515fb7ef637SJakub Kuderski } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { 1516fb7ef637SJakub Kuderski lhs = createTranspose(lhs); 1517fb7ef637SJakub Kuderski } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { 1518fb7ef637SJakub Kuderski rhs = createTranspose(rhs); 1519fb7ef637SJakub Kuderski lhs = createTranspose(lhs); 1520fb7ef637SJakub Kuderski } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 1521fb7ef637SJakub Kuderski std::swap(rhs, lhs); 1522fb7ef637SJakub Kuderski rhs = createTranspose(rhs); 1523fb7ef637SJakub Kuderski lhs = createTranspose(lhs); 1524fb7ef637SJakub Kuderski } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 1525fb7ef637SJakub Kuderski std::swap(rhs, lhs); 1526fb7ef637SJakub Kuderski rhs = createTranspose(rhs); 1527fb7ef637SJakub Kuderski } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { 1528fb7ef637SJakub Kuderski std::swap(lhs, rhs); 1529fb7ef637SJakub Kuderski lhs = createTranspose(lhs); 1530fb7ef637SJakub Kuderski } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 1531fb7ef637SJakub Kuderski std::swap(lhs, rhs); 1532fb7ef637SJakub Kuderski } else { 1533fb7ef637SJakub Kuderski return rewriter.notifyMatchFailure(op, "unhandled contraction form"); 1534fb7ef637SJakub Kuderski } 1535fb7ef637SJakub Kuderski rewriter.replaceOpWithNewOp<vector::ContractionOp>( 1536fb7ef637SJakub Kuderski op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm), 1537fb7ef637SJakub Kuderski op.getIteratorTypes()); 1538fb7ef637SJakub Kuderski return success(); 1539fb7ef637SJakub Kuderski }; 1540fb7ef637SJakub Kuderski 1541fb7ef637SJakub Kuderski private: 1542fb7ef637SJakub Kuderski FilterConstraintType filter; 1543fb7ef637SJakub Kuderski }; 1544fb7ef637SJakub Kuderski 15459a795f0cSManish Gupta /// Pattern to fold arithmetic extensions on floating point data types into 15469a795f0cSManish Gupta /// vector contraction operations. linalg.matmul introduces arithmetic 15479a795f0cSManish Gupta /// extensions on its operands. Please mlir snippets below for more details. 15489a795f0cSManish Gupta /// ```mlir 15499a795f0cSManish Gupta /// "linalg.matmul"(%lhs, %rhs, %acc) ({ 15509a795f0cSManish Gupta /// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32): 15519a795f0cSManish Gupta /// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32 15529a795f0cSManish Gupta /// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32 15539a795f0cSManish Gupta /// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32 15549a795f0cSManish Gupta /// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32 15559a795f0cSManish Gupta /// "linalg.yield"(%acc) : (f32) -> () 15569a795f0cSManish Gupta /// }) 15579a795f0cSManish Gupta /// ``` 15589a795f0cSManish Gupta /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor 15599a795f0cSManish Gupta /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`. 15609a795f0cSManish Gupta /// This pattern folds the arithmetic extensions into the vector contraction and 15619a795f0cSManish Gupta /// enables the usage of native mixed precision Tensor Core instructions. 1562ac1e22f3SStanley Winata template <typename ExtOp> 15639a795f0cSManish Gupta struct FoldArithExtIntoContractionOp 15649a795f0cSManish Gupta : public OpRewritePattern<vector::ContractionOp> { 15659a795f0cSManish Gupta using OpRewritePattern::OpRewritePattern; 15669a795f0cSManish Gupta 15679a795f0cSManish Gupta LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 15689a795f0cSManish Gupta PatternRewriter &rewriter) const override { 15699a795f0cSManish Gupta 1570ac1e22f3SStanley Winata auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>(); 1571ac1e22f3SStanley Winata auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>(); 15729a795f0cSManish Gupta 15739a795f0cSManish Gupta if (!lhsDefOp || !rhsDefOp) { 15749a795f0cSManish Gupta return rewriter.notifyMatchFailure(contractOp, 15759a795f0cSManish Gupta "no defining op on contract operands"); 15769a795f0cSManish Gupta } 15779a795f0cSManish Gupta 15789a795f0cSManish Gupta rewriter.replaceOpWithNewOp<vector::ContractionOp>( 15799a795f0cSManish Gupta contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0), 15809a795f0cSManish Gupta contractOp.getAcc(), contractOp.getIndexingMapsAttr(), 15819a795f0cSManish Gupta contractOp.getIteratorTypesAttr()); 15829a795f0cSManish Gupta 15839a795f0cSManish Gupta return success(); 15849a795f0cSManish Gupta } 15859a795f0cSManish Gupta }; 15869a795f0cSManish Gupta 1587d33bad66SJakub Kuderski /// Pattern to fold chained reduction to a series of vector additions and a 1588d33bad66SJakub Kuderski /// final reduction. This form should require fewer subgroup operations. 1589d33bad66SJakub Kuderski /// 1590d33bad66SJakub Kuderski /// ```mlir 1591d33bad66SJakub Kuderski /// %a = vector.reduction <add> %x, %acc 1592d33bad66SJakub Kuderski /// %b = vector.reduction <add> %y, %a 1593d33bad66SJakub Kuderski /// ==> 1594d33bad66SJakub Kuderski /// %a = arith.addf %x, %y 1595d33bad66SJakub Kuderski /// %b = vector.reduction <add> %a, %acc 1596d33bad66SJakub Kuderski /// ``` 1597d33bad66SJakub Kuderski struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> { 1598d33bad66SJakub Kuderski using OpRewritePattern::OpRewritePattern; 1599d33bad66SJakub Kuderski 1600d33bad66SJakub Kuderski LogicalResult matchAndRewrite(vector::ReductionOp op, 1601d33bad66SJakub Kuderski PatternRewriter &rewriter) const override { 1602d33bad66SJakub Kuderski // TODO: Handle other combining kinds. 1603d33bad66SJakub Kuderski if (op.getKind() != vector::CombiningKind::ADD) 1604d33bad66SJakub Kuderski return failure(); 1605d33bad66SJakub Kuderski 1606d33bad66SJakub Kuderski // Accumulator is optional. 1607d33bad66SJakub Kuderski Value acc = op.getAcc(); 1608d33bad66SJakub Kuderski if (!acc) 1609d33bad66SJakub Kuderski return failure(); 1610d33bad66SJakub Kuderski 1611d33bad66SJakub Kuderski if (!acc.getType().isIntOrFloat()) 1612d33bad66SJakub Kuderski return failure(); 1613d33bad66SJakub Kuderski 1614d33bad66SJakub Kuderski auto parentReduction = acc.getDefiningOp<vector::ReductionOp>(); 1615d33bad66SJakub Kuderski if (!parentReduction) 1616d33bad66SJakub Kuderski return failure(); 1617d33bad66SJakub Kuderski 1618d33bad66SJakub Kuderski Location loc = op.getLoc(); 1619d33bad66SJakub Kuderski Value vAdd; 1620d33bad66SJakub Kuderski if (isa<IntegerType>(acc.getType())) { 1621d33bad66SJakub Kuderski vAdd = rewriter.createOrFold<arith::AddIOp>( 1622d33bad66SJakub Kuderski loc, parentReduction.getVector(), op.getVector()); 1623d33bad66SJakub Kuderski } else { 1624d33bad66SJakub Kuderski vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(), 1625d33bad66SJakub Kuderski op.getVector()); 1626d33bad66SJakub Kuderski } 1627d33bad66SJakub Kuderski rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd, 1628d33bad66SJakub Kuderski parentReduction.getAcc()); 1629d33bad66SJakub Kuderski return success(); 1630d33bad66SJakub Kuderski } 1631d33bad66SJakub Kuderski }; 1632d33bad66SJakub Kuderski 1633de61875eSHugo Trachino // Helper function dropping unit non-scalable dimension from a VectorType 1634de61875eSHugo Trachino // keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit 1635de61875eSHugo Trachino // dimensions are not dropped. Folding such dimensions would require "shifting" 1636de61875eSHugo Trachino // the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> -> 1637de61875eSHugo Trachino // vector<[4]xf32>). This could be implemented in the future. 1638de61875eSHugo Trachino static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) { 1639de61875eSHugo Trachino auto inVecShape = inVecTy.getShape(); 1640de61875eSHugo Trachino SmallVector<int64_t> newShape; 1641de61875eSHugo Trachino SmallVector<bool> newScalableDims; 1642de61875eSHugo Trachino for (auto [dim, isScalable] : 1643de61875eSHugo Trachino llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) { 1644de61875eSHugo Trachino if (dim == 1 && !isScalable) 1645de61875eSHugo Trachino continue; 1646de61875eSHugo Trachino 1647de61875eSHugo Trachino newShape.push_back(dim); 1648de61875eSHugo Trachino newScalableDims.push_back(isScalable); 1649de61875eSHugo Trachino } 1650de61875eSHugo Trachino // All dims have been dropped, return vector<1xeType>. 1651de61875eSHugo Trachino if (newShape.empty()) { 1652de61875eSHugo Trachino newShape.push_back(1); 1653de61875eSHugo Trachino newScalableDims.push_back(false); 1654de61875eSHugo Trachino } 1655de61875eSHugo Trachino 1656de61875eSHugo Trachino return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims); 1657de61875eSHugo Trachino } 1658de61875eSHugo Trachino 1659de61875eSHugo Trachino /// For vectors with at least one unit dim, replaces: 1660c02d07fdSAndrzej Warzyński /// elementwise(a, b) 1661c02d07fdSAndrzej Warzyński /// with: 1662c02d07fdSAndrzej Warzyński /// sc_a = shape_cast(a) 1663c02d07fdSAndrzej Warzyński /// sc_b = shape_cast(b) 1664c02d07fdSAndrzej Warzyński /// res = elementwise(sc_a, sc_b) 1665c02d07fdSAndrzej Warzyński /// return shape_cast(res) 1666c02d07fdSAndrzej Warzyński /// The newly inserted shape_cast Ops fold (before elementwise Op) and then 1667c02d07fdSAndrzej Warzyński /// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are 1668c02d07fdSAndrzej Warzyński /// required to be rank > 1. 1669c02d07fdSAndrzej Warzyński /// 1670c02d07fdSAndrzej Warzyński /// Ex: 1671c02d07fdSAndrzej Warzyński /// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32> 1672c02d07fdSAndrzej Warzyński /// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32> 1673c02d07fdSAndrzej Warzyński /// 1674c02d07fdSAndrzej Warzyński /// gets converted to: 1675c02d07fdSAndrzej Warzyński /// 1676c02d07fdSAndrzej Warzyński /// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32> 1677c02d07fdSAndrzej Warzyński /// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32> 1678c02d07fdSAndrzej Warzyński /// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32> 1679c02d07fdSAndrzej Warzyński /// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32> 1680c02d07fdSAndrzej Warzyński /// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32> 1681c02d07fdSAndrzej Warzyński /// 1682c02d07fdSAndrzej Warzyński /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and 1683c02d07fdSAndrzej Warzyński /// `%cast`. 1684c02d07fdSAndrzej Warzyński struct DropUnitDimFromElementwiseOps final 1685c02d07fdSAndrzej Warzyński : public OpTraitRewritePattern<OpTrait::Elementwise> { 1686c02d07fdSAndrzej Warzyński using OpTraitRewritePattern::OpTraitRewritePattern; 1687c02d07fdSAndrzej Warzyński LogicalResult matchAndRewrite(Operation *op, 1688c02d07fdSAndrzej Warzyński PatternRewriter &rewriter) const override { 16892c9ba9c3SJerry Wu if (op->getNumResults() != 1 || op->getNumRegions() != 0) 1690c02d07fdSAndrzej Warzyński return failure(); 1691c02d07fdSAndrzej Warzyński 16922c9ba9c3SJerry Wu auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType()); 16932c9ba9c3SJerry Wu if (!resultVectorType) 16942c9ba9c3SJerry Wu return failure(); 16952c9ba9c3SJerry Wu 169613b37626SDiego Caballero // Check the operand pre-conditions. For `Elementwise` ops all operands are 169713b37626SDiego Caballero // guaranteed to have identical shapes (with some exceptions such as 169813b37626SDiego Caballero // `arith.select`) and it suffices to only check one of them. 169913b37626SDiego Caballero auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType()); 1700eaabd762SHan-Chung Wang if (!sourceVectorType) 170113b37626SDiego Caballero return failure(); 1702eaabd762SHan-Chung Wang if (sourceVectorType.getRank() < 2) 1703eaabd762SHan-Chung Wang return failure(); 1704eaabd762SHan-Chung Wang 1705c02d07fdSAndrzej Warzyński SmallVector<Value> newOperands; 1706c02d07fdSAndrzej Warzyński auto loc = op->getLoc(); 1707c02d07fdSAndrzej Warzyński for (auto operand : op->getOperands()) { 17082c9ba9c3SJerry Wu auto opVectorType = cast<VectorType>(operand.getType()); 1709de61875eSHugo Trachino auto newVType = dropNonScalableUnitDimFromType(opVectorType); 1710de61875eSHugo Trachino if (newVType == opVectorType) 1711de61875eSHugo Trachino return rewriter.notifyMatchFailure(op, "No unit dimension to remove."); 1712de61875eSHugo Trachino 1713c02d07fdSAndrzej Warzyński auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand); 1714c02d07fdSAndrzej Warzyński newOperands.push_back(opSC); 1715c02d07fdSAndrzej Warzyński } 1716c02d07fdSAndrzej Warzyński 17172c9ba9c3SJerry Wu VectorType newResultVectorType = 1718de61875eSHugo Trachino dropNonScalableUnitDimFromType(resultVectorType); 1719de61875eSHugo Trachino // Create an updated elementwise Op without unit dim. 1720c02d07fdSAndrzej Warzyński Operation *elementwiseOp = 1721c02d07fdSAndrzej Warzyński rewriter.create(loc, op->getName().getIdentifier(), newOperands, 17222c9ba9c3SJerry Wu newResultVectorType, op->getAttrs()); 1723c02d07fdSAndrzej Warzyński 1724de61875eSHugo Trachino // Restore the unit dim by applying vector.shape_cast to the result. 17252c9ba9c3SJerry Wu rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType, 1726c02d07fdSAndrzej Warzyński elementwiseOp->getResult(0)); 1727c02d07fdSAndrzej Warzyński 1728c02d07fdSAndrzej Warzyński return success(); 1729c02d07fdSAndrzej Warzyński } 1730c02d07fdSAndrzej Warzyński }; 1731c02d07fdSAndrzej Warzyński 1732da8778e4SBenjamin Maxwell /// A pattern to drop unit dims from vector.transpose. 1733da8778e4SBenjamin Maxwell /// 1734da8778e4SBenjamin Maxwell /// Example: 1735da8778e4SBenjamin Maxwell /// 1736da8778e4SBenjamin Maxwell /// BEFORE: 1737da8778e4SBenjamin Maxwell /// ```mlir 1738da8778e4SBenjamin Maxwell /// %transpose = vector.transpose %vector, [3, 0, 1, 2] 1739da8778e4SBenjamin Maxwell /// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> 1740da8778e4SBenjamin Maxwell /// ``` 1741da8778e4SBenjamin Maxwell /// 1742da8778e4SBenjamin Maxwell /// AFTER: 1743da8778e4SBenjamin Maxwell /// ```mlir 1744da8778e4SBenjamin Maxwell /// %dropDims = vector.shape_cast %vector 1745da8778e4SBenjamin Maxwell /// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> 1746da8778e4SBenjamin Maxwell /// %transpose = vector.transpose %0, [1, 0] 1747da8778e4SBenjamin Maxwell /// : vector<4x[4]xf32> to vector<[4]x4xf32> 1748da8778e4SBenjamin Maxwell /// %restoreDims = vector.shape_cast %transpose 1749da8778e4SBenjamin Maxwell /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> 1750da8778e4SBenjamin Maxwell /// ``` 1751da8778e4SBenjamin Maxwell struct DropUnitDimsFromTransposeOp final 1752da8778e4SBenjamin Maxwell : OpRewritePattern<vector::TransposeOp> { 1753da8778e4SBenjamin Maxwell using OpRewritePattern::OpRewritePattern; 1754da8778e4SBenjamin Maxwell 1755da8778e4SBenjamin Maxwell LogicalResult matchAndRewrite(vector::TransposeOp op, 1756da8778e4SBenjamin Maxwell PatternRewriter &rewriter) const override { 1757da8778e4SBenjamin Maxwell VectorType sourceType = op.getSourceVectorType(); 1758da8778e4SBenjamin Maxwell VectorType sourceTypeWithoutUnitDims = 1759da8778e4SBenjamin Maxwell dropNonScalableUnitDimFromType(sourceType); 1760da8778e4SBenjamin Maxwell 1761da8778e4SBenjamin Maxwell if (sourceType == sourceTypeWithoutUnitDims) 1762da8778e4SBenjamin Maxwell return failure(); 1763da8778e4SBenjamin Maxwell 1764da8778e4SBenjamin Maxwell // Construct a map from dimIdx -> number of dims dropped before dimIdx. 1765da8778e4SBenjamin Maxwell auto sourceDims = llvm::to_vector(vector::getDims(sourceType)); 1766da8778e4SBenjamin Maxwell SmallVector<int64_t> droppedDimsBefore(sourceType.getRank()); 1767da8778e4SBenjamin Maxwell int64_t droppedDims = 0; 1768da8778e4SBenjamin Maxwell for (auto [i, dim] : llvm::enumerate(sourceDims)) { 1769da8778e4SBenjamin Maxwell droppedDimsBefore[i] = droppedDims; 1770da8778e4SBenjamin Maxwell if (dim == std::make_tuple(1, false)) 1771da8778e4SBenjamin Maxwell ++droppedDims; 1772da8778e4SBenjamin Maxwell } 1773da8778e4SBenjamin Maxwell 1774da8778e4SBenjamin Maxwell // Drop unit dims from transpose permutation. 1775da8778e4SBenjamin Maxwell ArrayRef<int64_t> perm = op.getPermutation(); 1776da8778e4SBenjamin Maxwell SmallVector<int64_t> newPerm; 1777da8778e4SBenjamin Maxwell for (int64_t idx : perm) { 1778da8778e4SBenjamin Maxwell if (sourceDims[idx] == std::make_tuple(1, false)) 1779da8778e4SBenjamin Maxwell continue; 1780da8778e4SBenjamin Maxwell newPerm.push_back(idx - droppedDimsBefore[idx]); 1781da8778e4SBenjamin Maxwell } 1782da8778e4SBenjamin Maxwell 1783201da87cSHan-Chung Wang // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT> 1784201da87cSHan-Chung Wang // type when the dimensions are unit dimensions. In this case, the newPerm 1785201da87cSHan-Chung Wang // should be [0]. 1786201da87cSHan-Chung Wang if (newPerm.empty()) { 1787201da87cSHan-Chung Wang newPerm.push_back(0); 1788201da87cSHan-Chung Wang } 1789201da87cSHan-Chung Wang 1790da8778e4SBenjamin Maxwell Location loc = op.getLoc(); 1791da8778e4SBenjamin Maxwell // Drop the unit dims via shape_cast. 1792da8778e4SBenjamin Maxwell auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>( 1793da8778e4SBenjamin Maxwell loc, sourceTypeWithoutUnitDims, op.getVector()); 1794da8778e4SBenjamin Maxwell // Create the new transpose. 1795*aa295216SJay Foad auto transposeWithoutUnitDims = 1796da8778e4SBenjamin Maxwell rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm); 1797da8778e4SBenjamin Maxwell // Restore the unit dims via shape cast. 1798da8778e4SBenjamin Maxwell rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( 1799*aa295216SJay Foad op, op.getResultVectorType(), transposeWithoutUnitDims); 1800da8778e4SBenjamin Maxwell 1801da492d43SBenjamin Maxwell return success(); 1802da8778e4SBenjamin Maxwell } 1803da8778e4SBenjamin Maxwell }; 1804da8778e4SBenjamin Maxwell 1805a3b34e67SQuinn Dawkins /// A pattern to drop unit dims from the iter_args of an scf.for. 1806a3b34e67SQuinn Dawkins /// 1807a3b34e67SQuinn Dawkins /// Example: 1808a3b34e67SQuinn Dawkins /// 1809a3b34e67SQuinn Dawkins /// BEFORE: 1810a3b34e67SQuinn Dawkins /// ```mlir 1811a3b34e67SQuinn Dawkins /// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> { 1812a3b34e67SQuinn Dawkins /// ... 1813a3b34e67SQuinn Dawkins /// scf.yield % 1814a3b34e67SQuinn Dawkins /// } 1815a3b34e67SQuinn Dawkins /// ``` 1816a3b34e67SQuinn Dawkins /// 1817a3b34e67SQuinn Dawkins /// AFTER: 1818a3b34e67SQuinn Dawkins /// ```mlir 1819a3b34e67SQuinn Dawkins /// %drop = vector.shape_cast %init 1820a3b34e67SQuinn Dawkins /// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32> 1821a3b34e67SQuinn Dawkins /// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> { 1822a3b34e67SQuinn Dawkins /// %new_iter = vector.shape_cast %iter 1823a3b34e67SQuinn Dawkins /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> 1824a3b34e67SQuinn Dawkins /// ... 1825a3b34e67SQuinn Dawkins /// } 1826a3b34e67SQuinn Dawkins /// %res = vector.shape_cast %new_loop 1827a3b34e67SQuinn Dawkins /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> 1828a3b34e67SQuinn Dawkins /// ``` 1829a3b34e67SQuinn Dawkins struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { 1830a3b34e67SQuinn Dawkins using OpRewritePattern::OpRewritePattern; 1831a3b34e67SQuinn Dawkins 1832a3b34e67SQuinn Dawkins LogicalResult matchAndRewrite(scf::ForOp forOp, 1833a3b34e67SQuinn Dawkins PatternRewriter &rewriter) const override { 1834a3b34e67SQuinn Dawkins /// Find the first iter_arg with droppable unit dims. Further applications 1835a3b34e67SQuinn Dawkins /// of this pattern will apply to later arguments. 1836a3b34e67SQuinn Dawkins for (OpOperand &operand : forOp.getInitArgsMutable()) { 1837a3b34e67SQuinn Dawkins auto vectorType = dyn_cast<VectorType>(operand.get().getType()); 1838a3b34e67SQuinn Dawkins if (!vectorType) 1839a3b34e67SQuinn Dawkins continue; 1840a3b34e67SQuinn Dawkins 1841a3b34e67SQuinn Dawkins VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType); 1842a3b34e67SQuinn Dawkins if (vectorType == newVectorType) 1843a3b34e67SQuinn Dawkins continue; 1844a3b34e67SQuinn Dawkins 1845a3b34e67SQuinn Dawkins // Create a new ForOp with that iter operand replaced. 1846a3b34e67SQuinn Dawkins auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) { 1847a3b34e67SQuinn Dawkins return b.create<vector::ShapeCastOp>(loc, type, source); 1848a3b34e67SQuinn Dawkins }; 1849a3b34e67SQuinn Dawkins 1850a3b34e67SQuinn Dawkins Value replacement = 1851a3b34e67SQuinn Dawkins castFn(rewriter, forOp.getLoc(), newVectorType, operand.get()); 1852a3b34e67SQuinn Dawkins rewriter.replaceOp(forOp, 1853a3b34e67SQuinn Dawkins replaceAndCastForOpIterArg(rewriter, forOp, operand, 1854a3b34e67SQuinn Dawkins replacement, castFn)); 1855a3b34e67SQuinn Dawkins return success(); 1856a3b34e67SQuinn Dawkins } 1857a3b34e67SQuinn Dawkins return failure(); 1858a3b34e67SQuinn Dawkins } 1859a3b34e67SQuinn Dawkins }; 1860a3b34e67SQuinn Dawkins 1861d33bad66SJakub Kuderski /// Pattern to eliminate redundant zero-constants added to reduction operands. 1862d33bad66SJakub Kuderski /// It's enough for there to be one initial zero value, so we can eliminate the 1863d33bad66SJakub Kuderski /// extra ones that feed into `vector.reduction <add>`. These get created by the 1864d33bad66SJakub Kuderski /// `ChainedReduction` pattern. 1865d33bad66SJakub Kuderski /// 1866d33bad66SJakub Kuderski /// ```mlir 1867d33bad66SJakub Kuderski /// %a = arith.addf %x, %zero 1868d33bad66SJakub Kuderski /// %b = arith.addf %a, %y 1869d33bad66SJakub Kuderski /// %c = vector.reduction <add> %b, %acc 1870d33bad66SJakub Kuderski /// ==> 1871d33bad66SJakub Kuderski /// %b = arith.addf %a, %y 1872d33bad66SJakub Kuderski /// %c = vector.reduction <add> %b, %acc 1873d33bad66SJakub Kuderski /// ``` 1874d33bad66SJakub Kuderski struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> { 1875d33bad66SJakub Kuderski using OpRewritePattern::OpRewritePattern; 1876d33bad66SJakub Kuderski 1877d33bad66SJakub Kuderski LogicalResult matchAndRewrite(vector::ReductionOp op, 1878d33bad66SJakub Kuderski PatternRewriter &rewriter) const override { 1879d33bad66SJakub Kuderski // TODO: Handle other reduction kinds and their identity values. 1880d33bad66SJakub Kuderski if (op.getKind() != vector::CombiningKind::ADD) 1881d33bad66SJakub Kuderski return failure(); 1882d33bad66SJakub Kuderski 1883d33bad66SJakub Kuderski Type elemType = op.getSourceVectorType().getElementType(); 1884d33bad66SJakub Kuderski // The integer case should be handled by `arith.addi` folders, only check 1885d33bad66SJakub Kuderski // for floats here. 1886d33bad66SJakub Kuderski if (!isa<FloatType>(elemType)) 1887d33bad66SJakub Kuderski return failure(); 1888d33bad66SJakub Kuderski 1889d33bad66SJakub Kuderski auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>(); 1890d33bad66SJakub Kuderski if (!vAdd) 1891d33bad66SJakub Kuderski return failure(); 1892d33bad66SJakub Kuderski auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>(); 1893d33bad66SJakub Kuderski if (!addLhs) 1894d33bad66SJakub Kuderski return failure(); 1895d33bad66SJakub Kuderski 1896d33bad66SJakub Kuderski if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat())) 1897d33bad66SJakub Kuderski return failure(); 1898d33bad66SJakub Kuderski 1899d33bad66SJakub Kuderski auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(), 1900d33bad66SJakub Kuderski vAdd.getRhs()); 1901d33bad66SJakub Kuderski rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd, 1902d33bad66SJakub Kuderski op.getAcc()); 1903d33bad66SJakub Kuderski return success(); 1904d33bad66SJakub Kuderski } 1905d33bad66SJakub Kuderski }; 1906d33bad66SJakub Kuderski 190707677113SJakub Kuderski /// Example: 190807677113SJakub Kuderski /// ``` 190907677113SJakub Kuderski /// %a = vector.reduction <add> %x : vector<2xf32> into f32 191007677113SJakub Kuderski /// ``` 191107677113SJakub Kuderski /// is transformed into: 191207677113SJakub Kuderski /// ``` 191307677113SJakub Kuderski /// %y = vector.extract %x[0] : f32 from vector<2xf32> 191407677113SJakub Kuderski /// %z = vector.extract %x[1] : f32 from vector<2xf32> 191507677113SJakub Kuderski /// %a = arith.addf %y, %z : f32 191607677113SJakub Kuderski /// ``` 191707677113SJakub Kuderski struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> { 191807677113SJakub Kuderski BreakDownVectorReduction(MLIRContext *context, 191907677113SJakub Kuderski unsigned maxNumElementsToExtract, 192007677113SJakub Kuderski PatternBenefit benefit) 192107677113SJakub Kuderski : OpRewritePattern(context, benefit), 192207677113SJakub Kuderski maxNumElementsToExtract(maxNumElementsToExtract) {} 192307677113SJakub Kuderski 192407677113SJakub Kuderski LogicalResult matchAndRewrite(vector::ReductionOp op, 192507677113SJakub Kuderski PatternRewriter &rewriter) const override { 192607677113SJakub Kuderski VectorType type = op.getSourceVectorType(); 192707677113SJakub Kuderski if (type.isScalable() || op.isMasked()) 192807677113SJakub Kuderski return failure(); 192907677113SJakub Kuderski assert(type.getRank() == 1 && "Expected a 1-d vector"); 193007677113SJakub Kuderski 193107677113SJakub Kuderski int64_t numElems = type.getNumElements(); 193207677113SJakub Kuderski if (numElems > maxNumElementsToExtract) { 193307677113SJakub Kuderski return rewriter.notifyMatchFailure( 193407677113SJakub Kuderski op, llvm::formatv("has too many vector elements ({0}) to break down " 193507677113SJakub Kuderski "(max allowed: {1})", 193607677113SJakub Kuderski numElems, maxNumElementsToExtract)); 193707677113SJakub Kuderski } 193807677113SJakub Kuderski 193907677113SJakub Kuderski Location loc = op.getLoc(); 194007677113SJakub Kuderski SmallVector<Value> extracted(numElems, nullptr); 194107677113SJakub Kuderski for (auto [idx, extractedElem] : llvm::enumerate(extracted)) 194207677113SJakub Kuderski extractedElem = rewriter.create<vector::ExtractOp>( 194307677113SJakub Kuderski loc, op.getVector(), static_cast<int64_t>(idx)); 194407677113SJakub Kuderski 194507677113SJakub Kuderski Value res = extracted.front(); 194607677113SJakub Kuderski for (auto extractedElem : llvm::drop_begin(extracted)) 194707677113SJakub Kuderski res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, 194807677113SJakub Kuderski extractedElem, op.getFastmathAttr()); 194907677113SJakub Kuderski if (Value acc = op.getAcc()) 195007677113SJakub Kuderski res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc, 195107677113SJakub Kuderski op.getFastmathAttr()); 195207677113SJakub Kuderski 195307677113SJakub Kuderski rewriter.replaceOp(op, res); 195407677113SJakub Kuderski return success(); 195507677113SJakub Kuderski } 195607677113SJakub Kuderski 195707677113SJakub Kuderski private: 195807677113SJakub Kuderski unsigned maxNumElementsToExtract = 0; 195907677113SJakub Kuderski }; 196007677113SJakub Kuderski 19619f0aa05bSHugo Trachino /// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, 19629f0aa05bSHugo Trachino /// B)`. 19639f0aa05bSHugo Trachino /// Example: 19649f0aa05bSHugo Trachino /// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> 19659f0aa05bSHugo Trachino /// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to 19669f0aa05bSHugo Trachino /// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to 19679f0aa05bSHugo Trachino /// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32> 19689f0aa05bSHugo Trachino /// 19699f0aa05bSHugo Trachino /// Becomes : 19709f0aa05bSHugo Trachino /// 19719f0aa05bSHugo Trachino /// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> 19729f0aa05bSHugo Trachino /// 19739f0aa05bSHugo Trachino /// Supports only 1D-to-2D broadcasts. The following cases are not supported. 19749f0aa05bSHugo Trachino /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32> 19759f0aa05bSHugo Trachino /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32> 19769f0aa05bSHugo Trachino /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32> 19779f0aa05bSHugo Trachino template <typename MulOpType> 19789f0aa05bSHugo Trachino struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> { 19799f0aa05bSHugo Trachino using OpRewritePattern<MulOpType>::OpRewritePattern; 19809f0aa05bSHugo Trachino // Returns whether a vector.broadcast matches requirements for an outerproduct 19819f0aa05bSHugo Trachino // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension. 19829f0aa05bSHugo Trachino bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const { 19839f0aa05bSHugo Trachino // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating 19849f0aa05bSHugo Trachino // shape_casts/broadcasts which does not belong in this pattern. 19859f0aa05bSHugo Trachino if (!broadcastOp.computeBroadcastedUnitDims().empty()) 19869f0aa05bSHugo Trachino return false; 19879f0aa05bSHugo Trachino // Avoid broadcast like f32 or vector<f32> -> ResType 19889f0aa05bSHugo Trachino auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()); 19899f0aa05bSHugo Trachino return srcType && srcType.getRank() != 2; 19909f0aa05bSHugo Trachino } 19919f0aa05bSHugo Trachino 19929f0aa05bSHugo Trachino LogicalResult matchAndRewrite(MulOpType mulOp, 19939f0aa05bSHugo Trachino PatternRewriter &rewriter) const override { 19949f0aa05bSHugo Trachino auto resType = llvm::cast<VectorType>(mulOp.getResult().getType()); 19959f0aa05bSHugo Trachino if (!resType) 19969f0aa05bSHugo Trachino return failure(); 19979f0aa05bSHugo Trachino if (resType.getRank() != 2) 19989f0aa05bSHugo Trachino return failure(); 19999f0aa05bSHugo Trachino /// If operandA can be written as tr(broadcast(A)) and operandB as 20009f0aa05bSHugo Trachino /// broadcast(B) where broadcasts are 1D-to-2D, create and return 20019f0aa05bSHugo Trachino /// vector.outerproduct(A, B). Returns failure() otherwise. 20029f0aa05bSHugo Trachino auto matchOuterProduct = 20039f0aa05bSHugo Trachino [&](Value operandA, 20049f0aa05bSHugo Trachino Value operandB) -> FailureOr<vector::OuterProductOp> { 20059f0aa05bSHugo Trachino auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>(); 20069f0aa05bSHugo Trachino if (!transposedLhs) 20079f0aa05bSHugo Trachino return failure(); 20089f0aa05bSHugo Trachino // Fail unless this is a true 2-D matrix transpose. 20099f0aa05bSHugo Trachino ArrayRef<int64_t> permutation = transposedLhs.getPermutation(); 20109f0aa05bSHugo Trachino if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0) 20119f0aa05bSHugo Trachino return failure(); 20129f0aa05bSHugo Trachino 20139f0aa05bSHugo Trachino auto broadcastedLhs = 20149f0aa05bSHugo Trachino transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>(); 20159f0aa05bSHugo Trachino if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs)) 20169f0aa05bSHugo Trachino return failure(); 20179f0aa05bSHugo Trachino 20189f0aa05bSHugo Trachino auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>(); 20199f0aa05bSHugo Trachino if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs)) 20209f0aa05bSHugo Trachino return failure(); 20219f0aa05bSHugo Trachino 20229f0aa05bSHugo Trachino return rewriter.create<vector::OuterProductOp>( 20239f0aa05bSHugo Trachino mulOp->getLoc(), resType, broadcastedLhs.getSource(), 20249f0aa05bSHugo Trachino broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD); 20259f0aa05bSHugo Trachino }; 20269f0aa05bSHugo Trachino 20279f0aa05bSHugo Trachino Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1); 20289f0aa05bSHugo Trachino auto maybeOuterP = matchOuterProduct(lhs, rhs); 20299f0aa05bSHugo Trachino // Handle commutativity, the transposed op is the outerproduct LHS. 20309f0aa05bSHugo Trachino if (failed(maybeOuterP)) 20319f0aa05bSHugo Trachino maybeOuterP = matchOuterProduct(rhs, lhs); 20329f0aa05bSHugo Trachino if (failed(maybeOuterP)) 20339f0aa05bSHugo Trachino return failure(); 20349f0aa05bSHugo Trachino rewriter.replaceOp(mulOp, maybeOuterP->getResult()); 20359f0aa05bSHugo Trachino return success(); 20369f0aa05bSHugo Trachino } 20379f0aa05bSHugo Trachino }; 20389f0aa05bSHugo Trachino 203999ef9eebSMatthias Springer } // namespace 204099ef9eebSMatthias Springer 20419a795f0cSManish Gupta void mlir::vector::populateFoldArithExtensionPatterns( 20429a795f0cSManish Gupta RewritePatternSet &patterns) { 2043ac1e22f3SStanley Winata patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>, 2044ac1e22f3SStanley Winata FoldArithExtIntoContractionOp<arith::ExtSIOp>>( 2045ac1e22f3SStanley Winata patterns.getContext()); 20469a795f0cSManish Gupta } 20479a795f0cSManish Gupta 204899ef9eebSMatthias Springer void mlir::vector::populateVectorMaskMaterializationPatterns( 204927cc31b6SNicolas Vasilache RewritePatternSet &patterns, bool force32BitVectorIndices, 205027cc31b6SNicolas Vasilache PatternBenefit benefit) { 205199ef9eebSMatthias Springer patterns.add<VectorCreateMaskOpConversion, 205299ef9eebSMatthias Springer MaterializeTransferMask<vector::TransferReadOp>, 205399ef9eebSMatthias Springer MaterializeTransferMask<vector::TransferWriteOp>>( 205427cc31b6SNicolas Vasilache patterns.getContext(), force32BitVectorIndices, benefit); 205515a08cf2SDiego Caballero patterns.add<FoldI1Select>(patterns.getContext(), benefit); 205699ef9eebSMatthias Springer } 205799ef9eebSMatthias Springer 205827cc31b6SNicolas Vasilache void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns, 205927cc31b6SNicolas Vasilache PatternBenefit benefit) { 206027cc31b6SNicolas Vasilache patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit); 206199ef9eebSMatthias Springer } 206299ef9eebSMatthias Springer 2063c02d07fdSAndrzej Warzyński void mlir::vector::populateDropUnitDimWithShapeCastPatterns( 2064c02d07fdSAndrzej Warzyński RewritePatternSet &patterns, PatternBenefit benefit) { 20651f5e8263SAndrzej Warzyński // TODO: Consider either: 20661f5e8263SAndrzej Warzyński // * including DropInnerMostUnitDimsTransferRead and 20671f5e8263SAndrzej Warzyński // DropInnerMostUnitDimsTransferWrite, or 20681f5e8263SAndrzej Warzyński // * better naming to distinguish this and 20691f5e8263SAndrzej Warzyński // populateVectorTransferCollapseInnerMostContiguousDimsPatterns. 20701f5e8263SAndrzej Warzyński patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp, 20711f5e8263SAndrzej Warzyński DropUnitDimsFromTransposeOp, ShapeCastOpFolder>( 2072a3b34e67SQuinn Dawkins patterns.getContext(), benefit); 2073c02d07fdSAndrzej Warzyński } 2074c02d07fdSAndrzej Warzyński 207599ef9eebSMatthias Springer void mlir::vector::populateBubbleVectorBitCastOpPatterns( 207627cc31b6SNicolas Vasilache RewritePatternSet &patterns, PatternBenefit benefit) { 207799ef9eebSMatthias Springer patterns.add<BubbleDownVectorBitCastForExtract, 207899ef9eebSMatthias Springer BubbleDownBitCastForStridedSliceExtract, 20794623c114SDiego Caballero BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>( 20804623c114SDiego Caballero patterns.getContext(), benefit); 208199ef9eebSMatthias Springer } 208299ef9eebSMatthias Springer 2083650f04feSQuinn Dawkins void mlir::vector::populateBreakDownVectorBitCastOpPatterns( 2084650f04feSQuinn Dawkins RewritePatternSet &patterns, 2085650f04feSQuinn Dawkins std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) { 2086650f04feSQuinn Dawkins patterns.add<BreakDownVectorBitCast>(patterns.getContext(), 2087650f04feSQuinn Dawkins std::move(controlFn), benefit); 2088650f04feSQuinn Dawkins } 2089650f04feSQuinn Dawkins 2090fb7ef637SJakub Kuderski void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT( 2091fb7ef637SJakub Kuderski RewritePatternSet &patterns, 2092fb7ef637SJakub Kuderski std::function<LogicalResult(vector::ContractionOp)> constraint, 2093fb7ef637SJakub Kuderski PatternBenefit benefit) { 2094fb7ef637SJakub Kuderski patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit, 2095fb7ef637SJakub Kuderski std::move(constraint)); 2096fb7ef637SJakub Kuderski } 2097fb7ef637SJakub Kuderski 209899ef9eebSMatthias Springer void mlir::vector::populateVectorReductionToContractPatterns( 209927cc31b6SNicolas Vasilache RewritePatternSet &patterns, PatternBenefit benefit) { 210099ef9eebSMatthias Springer patterns.add<MultiReduceToContract, CombineContractBroadcast, 210142944da5SAndrzej Warzyński CombineContractABTranspose, CombineContractResultTranspose>( 2102f0c93fd4SLei Zhang patterns.getContext(), benefit); 210399ef9eebSMatthias Springer } 210499ef9eebSMatthias Springer 210599ef9eebSMatthias Springer void mlir::vector:: 210699ef9eebSMatthias Springer populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 210727cc31b6SNicolas Vasilache RewritePatternSet &patterns, PatternBenefit benefit) { 210812b676deSHan-Chung Wang patterns.add<DropInnerMostUnitDimsTransferRead, 210912b676deSHan-Chung Wang DropInnerMostUnitDimsTransferWrite>(patterns.getContext(), 211012b676deSHan-Chung Wang benefit); 211199ef9eebSMatthias Springer } 211299ef9eebSMatthias Springer 211342944da5SAndrzej Warzyński void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns, 211442944da5SAndrzej Warzyński PatternBenefit benefit) { 211542944da5SAndrzej Warzyński patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast, 211642944da5SAndrzej Warzyński ReorderElementwiseOpsOnBroadcast>(patterns.getContext(), 211742944da5SAndrzej Warzyński benefit); 21184d339ec9SAndrzej Warzynski } 21194d339ec9SAndrzej Warzynski 2120d33bad66SJakub Kuderski void mlir::vector::populateChainedVectorReductionFoldingPatterns( 2121d33bad66SJakub Kuderski RewritePatternSet &patterns, PatternBenefit benefit) { 2122d33bad66SJakub Kuderski patterns.add<ChainedReduction>(patterns.getContext(), benefit); 2123d33bad66SJakub Kuderski patterns.add<ReduceRedundantZero>(patterns.getContext(), 2124d33bad66SJakub Kuderski PatternBenefit(benefit.getBenefit() + 1)); 2125d33bad66SJakub Kuderski } 2126d33bad66SJakub Kuderski 212707677113SJakub Kuderski void mlir::vector::populateBreakDownVectorReductionPatterns( 212807677113SJakub Kuderski RewritePatternSet &patterns, unsigned maxNumElementsToExtract, 212907677113SJakub Kuderski PatternBenefit benefit) { 213007677113SJakub Kuderski patterns.add<BreakDownVectorReduction>(patterns.getContext(), 213107677113SJakub Kuderski maxNumElementsToExtract, benefit); 213207677113SJakub Kuderski } 213307677113SJakub Kuderski 21349f0aa05bSHugo Trachino void mlir::vector::populateElementwiseToVectorOpsPatterns( 21359f0aa05bSHugo Trachino RewritePatternSet &patterns) { 21369f0aa05bSHugo Trachino patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>, 21379f0aa05bSHugo Trachino FoldArithToVectorOuterProduct<arith::MulIOp>>( 21389f0aa05bSHugo Trachino patterns.getContext()); 21399f0aa05bSHugo Trachino } 21409f0aa05bSHugo Trachino 2141edec4239SQuentin Colombet //===----------------------------------------------------------------------===// 2142edec4239SQuentin Colombet // TableGen'd enum attribute definitions 2143edec4239SQuentin Colombet //===----------------------------------------------------------------------===// 2144edec4239SQuentin Colombet 2145edec4239SQuentin Colombet #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc" 2146