14c3db258SAdam Siemieniuk //===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===// 24c3db258SAdam Siemieniuk // 34c3db258SAdam Siemieniuk // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44c3db258SAdam Siemieniuk // See https://llvm.org/LICENSE.txt for license information. 54c3db258SAdam Siemieniuk // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64c3db258SAdam Siemieniuk // 74c3db258SAdam Siemieniuk //===----------------------------------------------------------------------===// 84c3db258SAdam Siemieniuk 94c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/Passes.h" 104c3db258SAdam Siemieniuk 114c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/IR/Linalg.h" 124c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 134c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/Utils/Utils.h" 144c3db258SAdam Siemieniuk #include "mlir/IR/PatternMatch.h" 154c3db258SAdam Siemieniuk #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 164c3db258SAdam Siemieniuk #include "llvm/ADT/SmallVector.h" 174c3db258SAdam Siemieniuk #include "llvm/ADT/TypeSwitch.h" 184c3db258SAdam Siemieniuk 194c3db258SAdam Siemieniuk #include <optional> 204c3db258SAdam Siemieniuk 214c3db258SAdam Siemieniuk namespace mlir { 224c3db258SAdam Siemieniuk #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL 234c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/Passes.h.inc" 244c3db258SAdam Siemieniuk } // namespace mlir 254c3db258SAdam Siemieniuk 264c3db258SAdam Siemieniuk using namespace mlir; 274c3db258SAdam Siemieniuk using namespace mlir::linalg; 284c3db258SAdam Siemieniuk 294c3db258SAdam Siemieniuk /// Return constant range span or nullopt, otherwise. 304c3db258SAdam Siemieniuk static std::optional<int64_t> getConstantRange(const Range &range) { 314c3db258SAdam Siemieniuk std::optional<int64_t> stride = getConstantIntValue(range.stride); 324c3db258SAdam Siemieniuk if (!stride || *stride != 1) 334c3db258SAdam Siemieniuk return std::nullopt; 344c3db258SAdam Siemieniuk std::optional<int64_t> offset = getConstantIntValue(range.offset); 354c3db258SAdam Siemieniuk if (!offset) 364c3db258SAdam Siemieniuk return std::nullopt; 374c3db258SAdam Siemieniuk std::optional<int64_t> size = getConstantIntValue(range.size); 384c3db258SAdam Siemieniuk if (!size) 394c3db258SAdam Siemieniuk return std::nullopt; 404c3db258SAdam Siemieniuk return (*size - *offset); 414c3db258SAdam Siemieniuk } 424c3db258SAdam Siemieniuk 434c3db258SAdam Siemieniuk /// Return true if all dimensions are fully divisible by the respective tiles. 444c3db258SAdam Siemieniuk static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, 454c3db258SAdam Siemieniuk ArrayRef<OpFoldResult> tiles, 464c3db258SAdam Siemieniuk ArrayRef<int64_t> dims) { 474c3db258SAdam Siemieniuk if (dims.size() != tiles.size() || tiles.empty()) 484c3db258SAdam Siemieniuk return false; 494c3db258SAdam Siemieniuk 504c3db258SAdam Siemieniuk FailureOr<ContractionDimensions> contractDims = 514c3db258SAdam Siemieniuk inferContractionDims(linalgOp); 524c3db258SAdam Siemieniuk if (failed(contractDims)) 534c3db258SAdam Siemieniuk return false; 544c3db258SAdam Siemieniuk unsigned batchDimsOffset = contractDims->batch.size(); 554c3db258SAdam Siemieniuk 564c3db258SAdam Siemieniuk // Skip the batch dimension if present. 574c3db258SAdam Siemieniuk // Offset all dimensions accordingly. 58*9cbc1f29SHan-Chung Wang SmallVector<int64_t, 3> offsetDims(dims); 594c3db258SAdam Siemieniuk for (size_t i = 0; i < offsetDims.size(); i++) 604c3db258SAdam Siemieniuk offsetDims[i] += batchDimsOffset; 614c3db258SAdam Siemieniuk 624c3db258SAdam Siemieniuk auto tileOp = cast<TilingInterface>(linalgOp.getOperation()); 634c3db258SAdam Siemieniuk OpBuilder builder(tileOp); 644c3db258SAdam Siemieniuk OpBuilder::InsertionGuard guard(builder); 654c3db258SAdam Siemieniuk SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder); 664c3db258SAdam Siemieniuk 674c3db258SAdam Siemieniuk for (auto dim : llvm::enumerate(offsetDims)) { 684c3db258SAdam Siemieniuk if (dim.value() >= static_cast<int64_t>(iterationDomain.size())) 694c3db258SAdam Siemieniuk return false; 704c3db258SAdam Siemieniuk 714c3db258SAdam Siemieniuk std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]); 724c3db258SAdam Siemieniuk std::optional<int64_t> rangeOnDim = 734c3db258SAdam Siemieniuk getConstantRange(iterationDomain[dim.value()]); 744c3db258SAdam Siemieniuk 754c3db258SAdam Siemieniuk // If the tile factor or the range are non-constant, the tile size is 764c3db258SAdam Siemieniuk // considered to be invalid. 774c3db258SAdam Siemieniuk if (!tileSize || !rangeOnDim) 784c3db258SAdam Siemieniuk return false; 794c3db258SAdam Siemieniuk 804c3db258SAdam Siemieniuk // The dimension must be fully divisible by the tile. 814c3db258SAdam Siemieniuk if (*rangeOnDim % *tileSize != 0) 824c3db258SAdam Siemieniuk return false; 834c3db258SAdam Siemieniuk } 844c3db258SAdam Siemieniuk 854c3db258SAdam Siemieniuk return true; 864c3db258SAdam Siemieniuk } 874c3db258SAdam Siemieniuk 884c3db258SAdam Siemieniuk /// Return failure or packed matmul with one of its operands transposed. 894c3db258SAdam Siemieniuk static FailureOr<PackTransposeResult> 904c3db258SAdam Siemieniuk transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, 914c3db258SAdam Siemieniuk tensor::PackOp packOp, AffineMap operandMap, 924c3db258SAdam Siemieniuk ArrayRef<unsigned> blocksStartDimPos, 934c3db258SAdam Siemieniuk bool transposeOuterBlocks, bool transposeInnerBlocks) { 944c3db258SAdam Siemieniuk assert(operandMap.getNumDims() >= 4 && 954c3db258SAdam Siemieniuk "expected at least 4D prepacked matmul"); 964c3db258SAdam Siemieniuk assert(blocksStartDimPos.size() >= 2 && 974c3db258SAdam Siemieniuk "expected starting outer and inner block positions"); 984c3db258SAdam Siemieniuk 994c3db258SAdam Siemieniuk // Bias toward innermost dimensions. 1004c3db258SAdam Siemieniuk unsigned outerBlockPos = operandMap.getNumResults() - 4; 1014c3db258SAdam Siemieniuk unsigned innerBlockPos = operandMap.getNumResults() - 2; 1024c3db258SAdam Siemieniuk 1034c3db258SAdam Siemieniuk // Transpose control options define the desired block and element layout. 1044c3db258SAdam Siemieniuk // Block transposition (outer dimensions) or element transposition (inner 1054c3db258SAdam Siemieniuk // dimensions) may not be necessary depending on the original matmul data 1064c3db258SAdam Siemieniuk // layout. 1074c3db258SAdam Siemieniuk bool isOuterTransposed = 1084c3db258SAdam Siemieniuk operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2]; 1094c3db258SAdam Siemieniuk bool isInnerTransposed = 1104c3db258SAdam Siemieniuk operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back(); 1114c3db258SAdam Siemieniuk 1124c3db258SAdam Siemieniuk // Transpose only the dimensions that need that to conform to the provided 1134c3db258SAdam Siemieniuk // transpotion settings. 114*9cbc1f29SHan-Chung Wang SmallVector<int64_t> innerPerm = {0, 1}; 1154c3db258SAdam Siemieniuk if (isInnerTransposed != transposeInnerBlocks) 1164c3db258SAdam Siemieniuk innerPerm = {1, 0}; 117*9cbc1f29SHan-Chung Wang SmallVector<int64_t> outerPerm = {0, 1}; 1184c3db258SAdam Siemieniuk if (isOuterTransposed != transposeOuterBlocks) 1194c3db258SAdam Siemieniuk outerPerm = {1, 0}; 1204c3db258SAdam Siemieniuk 1214c3db258SAdam Siemieniuk // Leave the outer dimensions, like batch, unchanged by offsetting all 1224c3db258SAdam Siemieniuk // outer dimensions permutations. 1234c3db258SAdam Siemieniuk SmallVector<int64_t> offsetPerms; 1244c3db258SAdam Siemieniuk for (auto i : llvm::seq(0u, outerBlockPos)) 1254c3db258SAdam Siemieniuk offsetPerms.push_back(i); 1264c3db258SAdam Siemieniuk for (auto perm : outerPerm) 1274c3db258SAdam Siemieniuk offsetPerms.push_back(perm + outerBlockPos); 1284c3db258SAdam Siemieniuk outerPerm = offsetPerms; 1294c3db258SAdam Siemieniuk 1304c3db258SAdam Siemieniuk FailureOr<PackTransposeResult> packTransposedMatmul = 1314c3db258SAdam Siemieniuk packTranspose(rewriter, packOp, linalgOp, 1324c3db258SAdam Siemieniuk /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm); 1334c3db258SAdam Siemieniuk 1344c3db258SAdam Siemieniuk return packTransposedMatmul; 1354c3db258SAdam Siemieniuk } 1364c3db258SAdam Siemieniuk 1374c3db258SAdam Siemieniuk /// Pack a matmul operation into blocked 4D layout. 1384c3db258SAdam Siemieniuk FailureOr<PackResult> 1394c3db258SAdam Siemieniuk linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, 1404c3db258SAdam Siemieniuk const ControlBlockPackMatmulFn &controlPackMatmul) { 1414c3db258SAdam Siemieniuk if (linalgOp.hasPureBufferSemantics()) 1424c3db258SAdam Siemieniuk return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics"); 1434c3db258SAdam Siemieniuk 1444c3db258SAdam Siemieniuk std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp); 1454c3db258SAdam Siemieniuk if (!options) 1464c3db258SAdam Siemieniuk return rewriter.notifyMatchFailure(linalgOp, "invalid packing options"); 1474c3db258SAdam Siemieniuk 1484c3db258SAdam Siemieniuk if (options->blockFactors.size() != 3) 1494c3db258SAdam Siemieniuk return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors"); 1504c3db258SAdam Siemieniuk 1514c3db258SAdam Siemieniuk SmallVector<OpFoldResult> mnkTiles = 1524c3db258SAdam Siemieniuk getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors)); 1534c3db258SAdam Siemieniuk 1544c3db258SAdam Siemieniuk // If padding is disabled, make sure that dimensions can be packed cleanly. 1554c3db258SAdam Siemieniuk if (!options->allowPadding && 1564c3db258SAdam Siemieniuk !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) { 1574c3db258SAdam Siemieniuk return rewriter.notifyMatchFailure(linalgOp, 1584c3db258SAdam Siemieniuk "expect packing full tiles only"); 1594c3db258SAdam Siemieniuk } 1604c3db258SAdam Siemieniuk 1614c3db258SAdam Siemieniuk OpBuilder::InsertionGuard guard(rewriter); 1624c3db258SAdam Siemieniuk // The op is replaced, we need to set the insertion point after it. 1634c3db258SAdam Siemieniuk rewriter.setInsertionPointAfter(linalgOp); 1644c3db258SAdam Siemieniuk 1654c3db258SAdam Siemieniuk // Pack the matmul operation into blocked layout with two levels of 1664c3db258SAdam Siemieniuk // subdivision: 1674c3db258SAdam Siemieniuk // - major 2D blocks - outer dimensions, consist of minor blocks 1684c3db258SAdam Siemieniuk // - minor 2D blocks - inner dimensions, consist of scalar elements 1694c3db258SAdam Siemieniuk FailureOr<PackResult> packedMatmul = packMatmulGreedily( 1704c3db258SAdam Siemieniuk rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf, 1714c3db258SAdam Siemieniuk options->mnkOrder); 1724c3db258SAdam Siemieniuk if (failed(packedMatmul)) 1734c3db258SAdam Siemieniuk return failure(); 1744c3db258SAdam Siemieniuk 1754c3db258SAdam Siemieniuk assert(packedMatmul->packOps.size() == 3 && 1764c3db258SAdam Siemieniuk "invalid number of pack ops after matmul packing"); 1774c3db258SAdam Siemieniuk assert(packedMatmul->unPackOps.size() == 1 && 1784c3db258SAdam Siemieniuk "invalid number of unpack ops after matmul packing"); 1794c3db258SAdam Siemieniuk 1804c3db258SAdam Siemieniuk FailureOr<ContractionDimensions> contractDims = 1814c3db258SAdam Siemieniuk inferContractionDims(packedMatmul->packedLinalgOp); 1824c3db258SAdam Siemieniuk if (failed(contractDims)) 1834c3db258SAdam Siemieniuk return failure(); 1844c3db258SAdam Siemieniuk 1854c3db258SAdam Siemieniuk auto genericOp = 1864c3db258SAdam Siemieniuk dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation()); 1874c3db258SAdam Siemieniuk SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray(); 1884c3db258SAdam Siemieniuk 1894c3db258SAdam Siemieniuk // Transpose LHS matrix according to the options. 1904c3db258SAdam Siemieniuk FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul( 1914c3db258SAdam Siemieniuk rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0], 1924c3db258SAdam Siemieniuk contractDims->m, options->lhsTransposeOuterBlocks, 1934c3db258SAdam Siemieniuk options->lhsTransposeInnerBlocks); 1944c3db258SAdam Siemieniuk if (failed(packedLhs)) 1954c3db258SAdam Siemieniuk return failure(); 1964c3db258SAdam Siemieniuk 1974c3db258SAdam Siemieniuk // Update results. 1984c3db258SAdam Siemieniuk packedMatmul->packOps[0] = packedLhs->transposedPackOp; 1994c3db258SAdam Siemieniuk packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp; 2004c3db258SAdam Siemieniuk 2014c3db258SAdam Siemieniuk // Transpose RHS matrix according to the options. 2024c3db258SAdam Siemieniuk FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul( 2034c3db258SAdam Siemieniuk rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1], 2044c3db258SAdam Siemieniuk contractDims->k, options->rhsTransposeOuterBlocks, 2054c3db258SAdam Siemieniuk options->rhsTransposeInnerBlocks); 2064c3db258SAdam Siemieniuk if (failed(packedRhs)) 2074c3db258SAdam Siemieniuk return failure(); 2084c3db258SAdam Siemieniuk 2094c3db258SAdam Siemieniuk // Update results. 2104c3db258SAdam Siemieniuk packedMatmul->packOps[1] = packedRhs->transposedPackOp; 2114c3db258SAdam Siemieniuk packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp; 2124c3db258SAdam Siemieniuk 2134c3db258SAdam Siemieniuk return packedMatmul; 2144c3db258SAdam Siemieniuk } 2154c3db258SAdam Siemieniuk 2164c3db258SAdam Siemieniuk namespace { 2174c3db258SAdam Siemieniuk template <typename OpTy> 2184c3db258SAdam Siemieniuk struct BlockPackMatmul : public OpRewritePattern<OpTy> { 2194c3db258SAdam Siemieniuk BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun, 2204c3db258SAdam Siemieniuk PatternBenefit benefit = 1) 2214c3db258SAdam Siemieniuk : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {} 2224c3db258SAdam Siemieniuk 2234c3db258SAdam Siemieniuk LogicalResult matchAndRewrite(OpTy linalgOp, 2244c3db258SAdam Siemieniuk PatternRewriter &rewriter) const override { 2254c3db258SAdam Siemieniuk FailureOr<PackResult> packedMatmul = 2264c3db258SAdam Siemieniuk blockPackMatmul(rewriter, linalgOp, controlFn); 2274c3db258SAdam Siemieniuk if (failed(packedMatmul)) 2284c3db258SAdam Siemieniuk return failure(); 2294c3db258SAdam Siemieniuk return success(); 2304c3db258SAdam Siemieniuk } 2314c3db258SAdam Siemieniuk 2324c3db258SAdam Siemieniuk private: 2334c3db258SAdam Siemieniuk ControlBlockPackMatmulFn controlFn; 2344c3db258SAdam Siemieniuk }; 2354c3db258SAdam Siemieniuk 2364c3db258SAdam Siemieniuk template <> 2374c3db258SAdam Siemieniuk struct BlockPackMatmul<linalg::GenericOp> 2384c3db258SAdam Siemieniuk : public OpRewritePattern<linalg::GenericOp> { 2394c3db258SAdam Siemieniuk BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun, 2404c3db258SAdam Siemieniuk PatternBenefit benefit = 1) 2414c3db258SAdam Siemieniuk : OpRewritePattern<linalg::GenericOp>(context, benefit), 2424c3db258SAdam Siemieniuk controlFn(std::move(fun)) {} 2434c3db258SAdam Siemieniuk 2444c3db258SAdam Siemieniuk LogicalResult matchAndRewrite(linalg::GenericOp linalgOp, 2454c3db258SAdam Siemieniuk PatternRewriter &rewriter) const override { 2464c3db258SAdam Siemieniuk // Match suitable generics. 247d776346aSAdam Siemieniuk if (!linalg::isaContractionOpInterface(linalgOp)) { 2484c3db258SAdam Siemieniuk return rewriter.notifyMatchFailure(linalgOp, "not a contraction"); 2494c3db258SAdam Siemieniuk } 2504c3db258SAdam Siemieniuk 2514c3db258SAdam Siemieniuk using MapList = ArrayRef<ArrayRef<AffineExpr>>; 2524c3db258SAdam Siemieniuk auto infer = [&](MapList m) { 2534c3db258SAdam Siemieniuk return AffineMap::inferFromExprList(m, linalgOp.getContext()); 2544c3db258SAdam Siemieniuk }; 2554c3db258SAdam Siemieniuk 2564c3db258SAdam Siemieniuk AffineExpr i, j, k; 2574c3db258SAdam Siemieniuk bindDims(linalgOp->getContext(), i, j, k); 2584c3db258SAdam Siemieniuk SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray(); 2594c3db258SAdam Siemieniuk 2604c3db258SAdam Siemieniuk // For now, only match simple matmuls. 2614c3db258SAdam Siemieniuk if (!(maps == infer({{i, k}, {k, j}, {i, j}}) || 2624c3db258SAdam Siemieniuk maps == infer({{k, i}, {k, j}, {i, j}}) || 2634c3db258SAdam Siemieniuk maps == infer({{i, k}, {j, k}, {i, j}}))) { 2644c3db258SAdam Siemieniuk return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul"); 2654c3db258SAdam Siemieniuk } 2664c3db258SAdam Siemieniuk 2674c3db258SAdam Siemieniuk FailureOr<PackResult> packedMatmul = 2684c3db258SAdam Siemieniuk blockPackMatmul(rewriter, linalgOp, controlFn); 2694c3db258SAdam Siemieniuk if (failed(packedMatmul)) 2704c3db258SAdam Siemieniuk return failure(); 2714c3db258SAdam Siemieniuk return success(); 2724c3db258SAdam Siemieniuk } 2734c3db258SAdam Siemieniuk 2744c3db258SAdam Siemieniuk private: 2754c3db258SAdam Siemieniuk ControlBlockPackMatmulFn controlFn; 2764c3db258SAdam Siemieniuk }; 2774c3db258SAdam Siemieniuk 2784c3db258SAdam Siemieniuk /// Convert linalg matmul ops to block layout and back. 2794c3db258SAdam Siemieniuk struct LinalgBlockPackMatmul 2804c3db258SAdam Siemieniuk : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> { 2814c3db258SAdam Siemieniuk using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase; 2824c3db258SAdam Siemieniuk 2834c3db258SAdam Siemieniuk void runOnOperation() override { 2844c3db258SAdam Siemieniuk Operation *op = getOperation(); 2854c3db258SAdam Siemieniuk RewritePatternSet patterns(&getContext()); 2864c3db258SAdam Siemieniuk 2874c3db258SAdam Siemieniuk ControlBlockPackMatmulFn controlFn = 2884c3db258SAdam Siemieniuk [&](linalg::LinalgOp op) -> BlockPackMatmulOptions { 2894c3db258SAdam Siemieniuk BlockPackMatmulOptions options; 2904c3db258SAdam Siemieniuk options.blockFactors = SmallVector<int64_t>{*blockFactors}; 2914c3db258SAdam Siemieniuk options.allowPadding = allowPadding; 2924c3db258SAdam Siemieniuk options.mnkPaddedSizesNextMultipleOf = 2934c3db258SAdam Siemieniuk SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf}; 2944c3db258SAdam Siemieniuk if (!mnkOrder.empty()) 2954c3db258SAdam Siemieniuk options.mnkOrder = SmallVector<int64_t>{*mnkOrder}; 2964c3db258SAdam Siemieniuk options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks; 2974c3db258SAdam Siemieniuk options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks; 2984c3db258SAdam Siemieniuk options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks; 2994c3db258SAdam Siemieniuk options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks; 3004c3db258SAdam Siemieniuk return options; 3014c3db258SAdam Siemieniuk }; 3024c3db258SAdam Siemieniuk 3034c3db258SAdam Siemieniuk linalg::populateBlockPackMatmulPatterns(patterns, controlFn); 30409dfc571SJacques Pienaar if (failed(applyPatternsGreedily(op, std::move(patterns)))) 3054c3db258SAdam Siemieniuk return signalPassFailure(); 3064c3db258SAdam Siemieniuk } 3074c3db258SAdam Siemieniuk }; 3084c3db258SAdam Siemieniuk } // namespace 3094c3db258SAdam Siemieniuk 3104c3db258SAdam Siemieniuk void linalg::populateBlockPackMatmulPatterns( 3114c3db258SAdam Siemieniuk RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) { 3124c3db258SAdam Siemieniuk patterns.add<BlockPackMatmul<linalg::GenericOp>, 3134c3db258SAdam Siemieniuk BlockPackMatmul<linalg::MatmulOp>, 3144c3db258SAdam Siemieniuk BlockPackMatmul<linalg::BatchMatmulOp>, 3154c3db258SAdam Siemieniuk BlockPackMatmul<linalg::MatmulTransposeAOp>, 3164c3db258SAdam Siemieniuk BlockPackMatmul<linalg::BatchMatmulTransposeAOp>, 3174c3db258SAdam Siemieniuk BlockPackMatmul<linalg::MatmulTransposeBOp>, 3184c3db258SAdam Siemieniuk BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>( 3194c3db258SAdam Siemieniuk patterns.getContext(), controlFn); 3204c3db258SAdam Siemieniuk } 321