//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is intended to be a simple high-level (target-agnostic) matmul
// transposition transformation.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "linalg-transpose-matmul"

using namespace mlir;
using namespace mlir::linalg;

/// Pattern to replace
///
///   linalg.matmul(a, b)
///
/// with
///
///   linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose the RHS instead.
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                     linalg::MatmulOp matmulOp,
                                                     bool transposeLHS) {
  // Do not let matmuls with extended semantics (user-defined indexing maps)
  // through this transform.
  if (matmulOp.hasUserDefinedMaps()) {
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with non-extended semantics are supported");
  }

  if (!bufferization::hasTensorSemantics(matmulOp))
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with tensors are supported");

  Location loc = matmulOp.getLoc();
  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto type = cast<ShapedType>(input.getType());

  // Collect dynamic dimension sizes in transposed order for the empty tensor
  // that will hold the transposed operand.
  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(1))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
  if (type.isDynamicDim(0))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

  ArrayRef<int64_t> shape = type.getShape();
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
      dynamicDims);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, ArrayRef<int64_t>{1, 0});
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
        matmulOp.getOutputs());
  } else {
    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
        matmulOp.getOutputs());
  }
  rewriter.replaceOp(matmulOp, newMatmulOp);
  return newMatmulOp;
}
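// For illustration, with `transposeLHS=true` the rewrite above turns
//
//   %0 = linalg.matmul ins(%a, %b : tensor<16x64xf32>, tensor<64x32xf32>)
//                      outs(%c : tensor<16x32xf32>) -> tensor<16x32xf32>
//
// into IR along these lines (a sketch assuming static shapes; the SSA names
// are illustrative, not what the rewriter actually assigns):
//
//   %empty = tensor.empty() : tensor<64x16xf32>
//   %at = linalg.transpose ins(%a : tensor<16x64xf32>)
//                          outs(%empty : tensor<64x16xf32>)
//                          permutation = [1, 0]
//   %0 = linalg.matmul_transpose_a
//            ins(%at, %b : tensor<64x16xf32>, tensor<64x32xf32>)
//            outs(%c : tensor<16x32xf32>) -> tensor<16x32xf32>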
/// Pattern to replace
///
///   linalg.batch_matmul(a, b)
///
/// with
///
///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose the RHS instead.
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                   linalg::BatchMatmulOp batchMatmulOp,
                                   bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(batchMatmulOp))
    return rewriter.notifyMatchFailure(
        batchMatmulOp, "only batch matmul ops with tensors are supported");

  Location loc = batchMatmulOp.getLoc();
  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto type = cast<ShapedType>(input.getType());

  // Collect dynamic dimension sizes in transposed order (batch dimension
  // first, then the two swapped non-batch dimensions).
  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(0))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
  if (type.isDynamicDim(2))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
  if (type.isDynamicDim(1))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));

  ArrayRef<int64_t> shape = type.getShape();
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
      type.getElementType(), dynamicDims);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
        batchMatmulOp.getOutputs());
  } else {
    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
        batchMatmulOp.getOutputs());
  }
  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
  return newMatmulOp;
}

namespace {
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};

struct TransposeBatchMatmul final
    : public OpRewritePattern<linalg::BatchMatmulOp> {
  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};
} // namespace

void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                                   bool transposeLHS) {
  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
                                                      transposeLHS);
}
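// A typical client populates these patterns into a pass and applies them
// greedily; a minimal sketch (the surrounding pass boilerplate is assumed,
// and older MLIR releases spell the driver `applyPatternsAndFoldGreedily`):
//
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateTransposeMatmulPatterns(patterns,
//                                             /*transposeLHS=*/true);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }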