xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp (revision 9cbc1f29cabc01c02a523c11d098c00650f6955c)
14c3db258SAdam Siemieniuk //===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
24c3db258SAdam Siemieniuk //
34c3db258SAdam Siemieniuk // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44c3db258SAdam Siemieniuk // See https://llvm.org/LICENSE.txt for license information.
54c3db258SAdam Siemieniuk // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64c3db258SAdam Siemieniuk //
74c3db258SAdam Siemieniuk //===----------------------------------------------------------------------===//
84c3db258SAdam Siemieniuk 
94c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/Passes.h"
104c3db258SAdam Siemieniuk 
114c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/IR/Linalg.h"
124c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
134c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/Utils/Utils.h"
144c3db258SAdam Siemieniuk #include "mlir/IR/PatternMatch.h"
154c3db258SAdam Siemieniuk #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
164c3db258SAdam Siemieniuk #include "llvm/ADT/SmallVector.h"
174c3db258SAdam Siemieniuk #include "llvm/ADT/TypeSwitch.h"
184c3db258SAdam Siemieniuk 
194c3db258SAdam Siemieniuk #include <optional>
204c3db258SAdam Siemieniuk 
214c3db258SAdam Siemieniuk namespace mlir {
224c3db258SAdam Siemieniuk #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
234c3db258SAdam Siemieniuk #include "mlir/Dialect/Linalg/Passes.h.inc"
244c3db258SAdam Siemieniuk } // namespace mlir
254c3db258SAdam Siemieniuk 
264c3db258SAdam Siemieniuk using namespace mlir;
274c3db258SAdam Siemieniuk using namespace mlir::linalg;
284c3db258SAdam Siemieniuk 
294c3db258SAdam Siemieniuk /// Return constant range span or nullopt, otherwise.
304c3db258SAdam Siemieniuk static std::optional<int64_t> getConstantRange(const Range &range) {
314c3db258SAdam Siemieniuk   std::optional<int64_t> stride = getConstantIntValue(range.stride);
324c3db258SAdam Siemieniuk   if (!stride || *stride != 1)
334c3db258SAdam Siemieniuk     return std::nullopt;
344c3db258SAdam Siemieniuk   std::optional<int64_t> offset = getConstantIntValue(range.offset);
354c3db258SAdam Siemieniuk   if (!offset)
364c3db258SAdam Siemieniuk     return std::nullopt;
374c3db258SAdam Siemieniuk   std::optional<int64_t> size = getConstantIntValue(range.size);
384c3db258SAdam Siemieniuk   if (!size)
394c3db258SAdam Siemieniuk     return std::nullopt;
404c3db258SAdam Siemieniuk   return (*size - *offset);
414c3db258SAdam Siemieniuk }
424c3db258SAdam Siemieniuk 
434c3db258SAdam Siemieniuk /// Return true if all dimensions are fully divisible by the respective tiles.
444c3db258SAdam Siemieniuk static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
454c3db258SAdam Siemieniuk                                     ArrayRef<OpFoldResult> tiles,
464c3db258SAdam Siemieniuk                                     ArrayRef<int64_t> dims) {
474c3db258SAdam Siemieniuk   if (dims.size() != tiles.size() || tiles.empty())
484c3db258SAdam Siemieniuk     return false;
494c3db258SAdam Siemieniuk 
504c3db258SAdam Siemieniuk   FailureOr<ContractionDimensions> contractDims =
514c3db258SAdam Siemieniuk       inferContractionDims(linalgOp);
524c3db258SAdam Siemieniuk   if (failed(contractDims))
534c3db258SAdam Siemieniuk     return false;
544c3db258SAdam Siemieniuk   unsigned batchDimsOffset = contractDims->batch.size();
554c3db258SAdam Siemieniuk 
564c3db258SAdam Siemieniuk   // Skip the batch dimension if present.
574c3db258SAdam Siemieniuk   // Offset all dimensions accordingly.
58*9cbc1f29SHan-Chung Wang   SmallVector<int64_t, 3> offsetDims(dims);
594c3db258SAdam Siemieniuk   for (size_t i = 0; i < offsetDims.size(); i++)
604c3db258SAdam Siemieniuk     offsetDims[i] += batchDimsOffset;
614c3db258SAdam Siemieniuk 
624c3db258SAdam Siemieniuk   auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
634c3db258SAdam Siemieniuk   OpBuilder builder(tileOp);
644c3db258SAdam Siemieniuk   OpBuilder::InsertionGuard guard(builder);
654c3db258SAdam Siemieniuk   SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
664c3db258SAdam Siemieniuk 
674c3db258SAdam Siemieniuk   for (auto dim : llvm::enumerate(offsetDims)) {
684c3db258SAdam Siemieniuk     if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
694c3db258SAdam Siemieniuk       return false;
704c3db258SAdam Siemieniuk 
714c3db258SAdam Siemieniuk     std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
724c3db258SAdam Siemieniuk     std::optional<int64_t> rangeOnDim =
734c3db258SAdam Siemieniuk         getConstantRange(iterationDomain[dim.value()]);
744c3db258SAdam Siemieniuk 
754c3db258SAdam Siemieniuk     // If the tile factor or the range are non-constant, the tile size is
764c3db258SAdam Siemieniuk     // considered to be invalid.
774c3db258SAdam Siemieniuk     if (!tileSize || !rangeOnDim)
784c3db258SAdam Siemieniuk       return false;
794c3db258SAdam Siemieniuk 
804c3db258SAdam Siemieniuk     // The dimension must be fully divisible by the tile.
814c3db258SAdam Siemieniuk     if (*rangeOnDim % *tileSize != 0)
824c3db258SAdam Siemieniuk       return false;
834c3db258SAdam Siemieniuk   }
844c3db258SAdam Siemieniuk 
854c3db258SAdam Siemieniuk   return true;
864c3db258SAdam Siemieniuk }
874c3db258SAdam Siemieniuk 
884c3db258SAdam Siemieniuk /// Return failure or packed matmul with one of its operands transposed.
894c3db258SAdam Siemieniuk static FailureOr<PackTransposeResult>
904c3db258SAdam Siemieniuk transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
914c3db258SAdam Siemieniuk                       tensor::PackOp packOp, AffineMap operandMap,
924c3db258SAdam Siemieniuk                       ArrayRef<unsigned> blocksStartDimPos,
934c3db258SAdam Siemieniuk                       bool transposeOuterBlocks, bool transposeInnerBlocks) {
944c3db258SAdam Siemieniuk   assert(operandMap.getNumDims() >= 4 &&
954c3db258SAdam Siemieniuk          "expected at least 4D prepacked matmul");
964c3db258SAdam Siemieniuk   assert(blocksStartDimPos.size() >= 2 &&
974c3db258SAdam Siemieniuk          "expected starting outer and inner block positions");
984c3db258SAdam Siemieniuk 
994c3db258SAdam Siemieniuk   // Bias toward innermost dimensions.
1004c3db258SAdam Siemieniuk   unsigned outerBlockPos = operandMap.getNumResults() - 4;
1014c3db258SAdam Siemieniuk   unsigned innerBlockPos = operandMap.getNumResults() - 2;
1024c3db258SAdam Siemieniuk 
1034c3db258SAdam Siemieniuk   // Transpose control options define the desired block and element layout.
1044c3db258SAdam Siemieniuk   // Block transposition (outer dimensions) or element transposition (inner
1054c3db258SAdam Siemieniuk   // dimensions) may not be necessary depending on the original matmul data
1064c3db258SAdam Siemieniuk   // layout.
1074c3db258SAdam Siemieniuk   bool isOuterTransposed =
1084c3db258SAdam Siemieniuk       operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
1094c3db258SAdam Siemieniuk   bool isInnerTransposed =
1104c3db258SAdam Siemieniuk       operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
1114c3db258SAdam Siemieniuk 
1124c3db258SAdam Siemieniuk   // Transpose only the dimensions that need that to conform to the provided
1134c3db258SAdam Siemieniuk   // transpotion settings.
114*9cbc1f29SHan-Chung Wang   SmallVector<int64_t> innerPerm = {0, 1};
1154c3db258SAdam Siemieniuk   if (isInnerTransposed != transposeInnerBlocks)
1164c3db258SAdam Siemieniuk     innerPerm = {1, 0};
117*9cbc1f29SHan-Chung Wang   SmallVector<int64_t> outerPerm = {0, 1};
1184c3db258SAdam Siemieniuk   if (isOuterTransposed != transposeOuterBlocks)
1194c3db258SAdam Siemieniuk     outerPerm = {1, 0};
1204c3db258SAdam Siemieniuk 
1214c3db258SAdam Siemieniuk   // Leave the outer dimensions, like batch, unchanged by offsetting all
1224c3db258SAdam Siemieniuk   // outer dimensions permutations.
1234c3db258SAdam Siemieniuk   SmallVector<int64_t> offsetPerms;
1244c3db258SAdam Siemieniuk   for (auto i : llvm::seq(0u, outerBlockPos))
1254c3db258SAdam Siemieniuk     offsetPerms.push_back(i);
1264c3db258SAdam Siemieniuk   for (auto perm : outerPerm)
1274c3db258SAdam Siemieniuk     offsetPerms.push_back(perm + outerBlockPos);
1284c3db258SAdam Siemieniuk   outerPerm = offsetPerms;
1294c3db258SAdam Siemieniuk 
1304c3db258SAdam Siemieniuk   FailureOr<PackTransposeResult> packTransposedMatmul =
1314c3db258SAdam Siemieniuk       packTranspose(rewriter, packOp, linalgOp,
1324c3db258SAdam Siemieniuk                     /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
1334c3db258SAdam Siemieniuk 
1344c3db258SAdam Siemieniuk   return packTransposedMatmul;
1354c3db258SAdam Siemieniuk }
1364c3db258SAdam Siemieniuk 
1374c3db258SAdam Siemieniuk /// Pack a matmul operation into blocked 4D layout.
1384c3db258SAdam Siemieniuk FailureOr<PackResult>
1394c3db258SAdam Siemieniuk linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
1404c3db258SAdam Siemieniuk                         const ControlBlockPackMatmulFn &controlPackMatmul) {
1414c3db258SAdam Siemieniuk   if (linalgOp.hasPureBufferSemantics())
1424c3db258SAdam Siemieniuk     return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
1434c3db258SAdam Siemieniuk 
1444c3db258SAdam Siemieniuk   std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
1454c3db258SAdam Siemieniuk   if (!options)
1464c3db258SAdam Siemieniuk     return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
1474c3db258SAdam Siemieniuk 
1484c3db258SAdam Siemieniuk   if (options->blockFactors.size() != 3)
1494c3db258SAdam Siemieniuk     return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
1504c3db258SAdam Siemieniuk 
1514c3db258SAdam Siemieniuk   SmallVector<OpFoldResult> mnkTiles =
1524c3db258SAdam Siemieniuk       getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
1534c3db258SAdam Siemieniuk 
1544c3db258SAdam Siemieniuk   // If padding is disabled, make sure that dimensions can be packed cleanly.
1554c3db258SAdam Siemieniuk   if (!options->allowPadding &&
1564c3db258SAdam Siemieniuk       !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
1574c3db258SAdam Siemieniuk     return rewriter.notifyMatchFailure(linalgOp,
1584c3db258SAdam Siemieniuk                                        "expect packing full tiles only");
1594c3db258SAdam Siemieniuk   }
1604c3db258SAdam Siemieniuk 
1614c3db258SAdam Siemieniuk   OpBuilder::InsertionGuard guard(rewriter);
1624c3db258SAdam Siemieniuk   // The op is replaced, we need to set the insertion point after it.
1634c3db258SAdam Siemieniuk   rewriter.setInsertionPointAfter(linalgOp);
1644c3db258SAdam Siemieniuk 
1654c3db258SAdam Siemieniuk   // Pack the matmul operation into blocked layout with two levels of
1664c3db258SAdam Siemieniuk   // subdivision:
1674c3db258SAdam Siemieniuk   //   - major 2D blocks - outer dimensions, consist of minor blocks
1684c3db258SAdam Siemieniuk   //   - minor 2D blocks - inner dimensions, consist of scalar elements
1694c3db258SAdam Siemieniuk   FailureOr<PackResult> packedMatmul = packMatmulGreedily(
1704c3db258SAdam Siemieniuk       rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
1714c3db258SAdam Siemieniuk       options->mnkOrder);
1724c3db258SAdam Siemieniuk   if (failed(packedMatmul))
1734c3db258SAdam Siemieniuk     return failure();
1744c3db258SAdam Siemieniuk 
1754c3db258SAdam Siemieniuk   assert(packedMatmul->packOps.size() == 3 &&
1764c3db258SAdam Siemieniuk          "invalid number of pack ops after matmul packing");
1774c3db258SAdam Siemieniuk   assert(packedMatmul->unPackOps.size() == 1 &&
1784c3db258SAdam Siemieniuk          "invalid number of unpack ops after matmul packing");
1794c3db258SAdam Siemieniuk 
1804c3db258SAdam Siemieniuk   FailureOr<ContractionDimensions> contractDims =
1814c3db258SAdam Siemieniuk       inferContractionDims(packedMatmul->packedLinalgOp);
1824c3db258SAdam Siemieniuk   if (failed(contractDims))
1834c3db258SAdam Siemieniuk     return failure();
1844c3db258SAdam Siemieniuk 
1854c3db258SAdam Siemieniuk   auto genericOp =
1864c3db258SAdam Siemieniuk       dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
1874c3db258SAdam Siemieniuk   SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
1884c3db258SAdam Siemieniuk 
1894c3db258SAdam Siemieniuk   // Transpose LHS matrix according to the options.
1904c3db258SAdam Siemieniuk   FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
1914c3db258SAdam Siemieniuk       rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
1924c3db258SAdam Siemieniuk       contractDims->m, options->lhsTransposeOuterBlocks,
1934c3db258SAdam Siemieniuk       options->lhsTransposeInnerBlocks);
1944c3db258SAdam Siemieniuk   if (failed(packedLhs))
1954c3db258SAdam Siemieniuk     return failure();
1964c3db258SAdam Siemieniuk 
1974c3db258SAdam Siemieniuk   // Update results.
1984c3db258SAdam Siemieniuk   packedMatmul->packOps[0] = packedLhs->transposedPackOp;
1994c3db258SAdam Siemieniuk   packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
2004c3db258SAdam Siemieniuk 
2014c3db258SAdam Siemieniuk   // Transpose RHS matrix according to the options.
2024c3db258SAdam Siemieniuk   FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
2034c3db258SAdam Siemieniuk       rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
2044c3db258SAdam Siemieniuk       contractDims->k, options->rhsTransposeOuterBlocks,
2054c3db258SAdam Siemieniuk       options->rhsTransposeInnerBlocks);
2064c3db258SAdam Siemieniuk   if (failed(packedRhs))
2074c3db258SAdam Siemieniuk     return failure();
2084c3db258SAdam Siemieniuk 
2094c3db258SAdam Siemieniuk   // Update results.
2104c3db258SAdam Siemieniuk   packedMatmul->packOps[1] = packedRhs->transposedPackOp;
2114c3db258SAdam Siemieniuk   packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
2124c3db258SAdam Siemieniuk 
2134c3db258SAdam Siemieniuk   return packedMatmul;
2144c3db258SAdam Siemieniuk }
2154c3db258SAdam Siemieniuk 
2164c3db258SAdam Siemieniuk namespace {
2174c3db258SAdam Siemieniuk template <typename OpTy>
2184c3db258SAdam Siemieniuk struct BlockPackMatmul : public OpRewritePattern<OpTy> {
2194c3db258SAdam Siemieniuk   BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
2204c3db258SAdam Siemieniuk                   PatternBenefit benefit = 1)
2214c3db258SAdam Siemieniuk       : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
2224c3db258SAdam Siemieniuk 
2234c3db258SAdam Siemieniuk   LogicalResult matchAndRewrite(OpTy linalgOp,
2244c3db258SAdam Siemieniuk                                 PatternRewriter &rewriter) const override {
2254c3db258SAdam Siemieniuk     FailureOr<PackResult> packedMatmul =
2264c3db258SAdam Siemieniuk         blockPackMatmul(rewriter, linalgOp, controlFn);
2274c3db258SAdam Siemieniuk     if (failed(packedMatmul))
2284c3db258SAdam Siemieniuk       return failure();
2294c3db258SAdam Siemieniuk     return success();
2304c3db258SAdam Siemieniuk   }
2314c3db258SAdam Siemieniuk 
2324c3db258SAdam Siemieniuk private:
2334c3db258SAdam Siemieniuk   ControlBlockPackMatmulFn controlFn;
2344c3db258SAdam Siemieniuk };
2354c3db258SAdam Siemieniuk 
2364c3db258SAdam Siemieniuk template <>
2374c3db258SAdam Siemieniuk struct BlockPackMatmul<linalg::GenericOp>
2384c3db258SAdam Siemieniuk     : public OpRewritePattern<linalg::GenericOp> {
2394c3db258SAdam Siemieniuk   BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
2404c3db258SAdam Siemieniuk                   PatternBenefit benefit = 1)
2414c3db258SAdam Siemieniuk       : OpRewritePattern<linalg::GenericOp>(context, benefit),
2424c3db258SAdam Siemieniuk         controlFn(std::move(fun)) {}
2434c3db258SAdam Siemieniuk 
2444c3db258SAdam Siemieniuk   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
2454c3db258SAdam Siemieniuk                                 PatternRewriter &rewriter) const override {
2464c3db258SAdam Siemieniuk     // Match suitable generics.
247d776346aSAdam Siemieniuk     if (!linalg::isaContractionOpInterface(linalgOp)) {
2484c3db258SAdam Siemieniuk       return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
2494c3db258SAdam Siemieniuk     }
2504c3db258SAdam Siemieniuk 
2514c3db258SAdam Siemieniuk     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
2524c3db258SAdam Siemieniuk     auto infer = [&](MapList m) {
2534c3db258SAdam Siemieniuk       return AffineMap::inferFromExprList(m, linalgOp.getContext());
2544c3db258SAdam Siemieniuk     };
2554c3db258SAdam Siemieniuk 
2564c3db258SAdam Siemieniuk     AffineExpr i, j, k;
2574c3db258SAdam Siemieniuk     bindDims(linalgOp->getContext(), i, j, k);
2584c3db258SAdam Siemieniuk     SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
2594c3db258SAdam Siemieniuk 
2604c3db258SAdam Siemieniuk     // For now, only match simple matmuls.
2614c3db258SAdam Siemieniuk     if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
2624c3db258SAdam Siemieniuk           maps == infer({{k, i}, {k, j}, {i, j}}) ||
2634c3db258SAdam Siemieniuk           maps == infer({{i, k}, {j, k}, {i, j}}))) {
2644c3db258SAdam Siemieniuk       return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
2654c3db258SAdam Siemieniuk     }
2664c3db258SAdam Siemieniuk 
2674c3db258SAdam Siemieniuk     FailureOr<PackResult> packedMatmul =
2684c3db258SAdam Siemieniuk         blockPackMatmul(rewriter, linalgOp, controlFn);
2694c3db258SAdam Siemieniuk     if (failed(packedMatmul))
2704c3db258SAdam Siemieniuk       return failure();
2714c3db258SAdam Siemieniuk     return success();
2724c3db258SAdam Siemieniuk   }
2734c3db258SAdam Siemieniuk 
2744c3db258SAdam Siemieniuk private:
2754c3db258SAdam Siemieniuk   ControlBlockPackMatmulFn controlFn;
2764c3db258SAdam Siemieniuk };
2774c3db258SAdam Siemieniuk 
2784c3db258SAdam Siemieniuk /// Convert linalg matmul ops to block layout and back.
2794c3db258SAdam Siemieniuk struct LinalgBlockPackMatmul
2804c3db258SAdam Siemieniuk     : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
2814c3db258SAdam Siemieniuk   using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
2824c3db258SAdam Siemieniuk 
2834c3db258SAdam Siemieniuk   void runOnOperation() override {
2844c3db258SAdam Siemieniuk     Operation *op = getOperation();
2854c3db258SAdam Siemieniuk     RewritePatternSet patterns(&getContext());
2864c3db258SAdam Siemieniuk 
2874c3db258SAdam Siemieniuk     ControlBlockPackMatmulFn controlFn =
2884c3db258SAdam Siemieniuk         [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
2894c3db258SAdam Siemieniuk       BlockPackMatmulOptions options;
2904c3db258SAdam Siemieniuk       options.blockFactors = SmallVector<int64_t>{*blockFactors};
2914c3db258SAdam Siemieniuk       options.allowPadding = allowPadding;
2924c3db258SAdam Siemieniuk       options.mnkPaddedSizesNextMultipleOf =
2934c3db258SAdam Siemieniuk           SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
2944c3db258SAdam Siemieniuk       if (!mnkOrder.empty())
2954c3db258SAdam Siemieniuk         options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
2964c3db258SAdam Siemieniuk       options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
2974c3db258SAdam Siemieniuk       options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
2984c3db258SAdam Siemieniuk       options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
2994c3db258SAdam Siemieniuk       options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
3004c3db258SAdam Siemieniuk       return options;
3014c3db258SAdam Siemieniuk     };
3024c3db258SAdam Siemieniuk 
3034c3db258SAdam Siemieniuk     linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
30409dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(op, std::move(patterns))))
3054c3db258SAdam Siemieniuk       return signalPassFailure();
3064c3db258SAdam Siemieniuk   }
3074c3db258SAdam Siemieniuk };
3084c3db258SAdam Siemieniuk } // namespace
3094c3db258SAdam Siemieniuk 
3104c3db258SAdam Siemieniuk void linalg::populateBlockPackMatmulPatterns(
3114c3db258SAdam Siemieniuk     RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
3124c3db258SAdam Siemieniuk   patterns.add<BlockPackMatmul<linalg::GenericOp>,
3134c3db258SAdam Siemieniuk                BlockPackMatmul<linalg::MatmulOp>,
3144c3db258SAdam Siemieniuk                BlockPackMatmul<linalg::BatchMatmulOp>,
3154c3db258SAdam Siemieniuk                BlockPackMatmul<linalg::MatmulTransposeAOp>,
3164c3db258SAdam Siemieniuk                BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
3174c3db258SAdam Siemieniuk                BlockPackMatmul<linalg::MatmulTransposeBOp>,
3184c3db258SAdam Siemieniuk                BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
3194c3db258SAdam Siemieniuk       patterns.getContext(), controlFn);
3204c3db258SAdam Siemieniuk }
321