xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (revision 3ad0148020ca91cc288bffd8ad36e25f7555a3bb)
179225349SCullen Rhodes //===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
279225349SCullen Rhodes //
379225349SCullen Rhodes // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
479225349SCullen Rhodes // See https://llvm.org/LICENSE.txt for license information.
579225349SCullen Rhodes // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
679225349SCullen Rhodes //
779225349SCullen Rhodes //===----------------------------------------------------------------------===//
879225349SCullen Rhodes // This is intended to be a simple high-level (target-agnostic) matmul
979225349SCullen Rhodes // transposition transformation.
1079225349SCullen Rhodes //===----------------------------------------------------------------------===//
1179225349SCullen Rhodes 
1279225349SCullen Rhodes #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1379225349SCullen Rhodes #include "mlir/IR/PatternMatch.h"
1479225349SCullen Rhodes #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1579225349SCullen Rhodes 
1679225349SCullen Rhodes #define DEBUG_TYPE "linalg-transpose-matmul"
1779225349SCullen Rhodes 
1879225349SCullen Rhodes using namespace mlir;
1979225349SCullen Rhodes using namespace mlir::linalg;
2079225349SCullen Rhodes 
2179225349SCullen Rhodes /// Pattern to replace
2279225349SCullen Rhodes ///
2379225349SCullen Rhodes ///   linalg.matmul(a, b)
2479225349SCullen Rhodes ///
2579225349SCullen Rhodes /// with
2679225349SCullen Rhodes ///
2779225349SCullen Rhodes ///   linalg.matmul_transpose_a(linalg.transpose(a), b)
2879225349SCullen Rhodes ///
2979225349SCullen Rhodes /// By default the LHS is transposed. Set `transposeLHS=false` to
3079225349SCullen Rhodes /// transpose RHS instead.
31be1c72d2SCullen Rhodes FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
32be1c72d2SCullen Rhodes                                                      linalg::MatmulOp matmulOp,
33be1c72d2SCullen Rhodes                                                      bool transposeLHS) {
34*3ad01480SMd Asghar Ahmad Shahid   // Check to not let go the matmul with extended semantic, through this
35*3ad01480SMd Asghar Ahmad Shahid   // transform.
36*3ad01480SMd Asghar Ahmad Shahid   if (matmulOp.hasUserDefinedMaps()) {
37*3ad01480SMd Asghar Ahmad Shahid     return rewriter.notifyMatchFailure(
38*3ad01480SMd Asghar Ahmad Shahid         matmulOp, "only matmul ops with non-extended semantics are supported");
39*3ad01480SMd Asghar Ahmad Shahid   }
40*3ad01480SMd Asghar Ahmad Shahid 
4179225349SCullen Rhodes   if (!bufferization::hasTensorSemantics(matmulOp))
4279225349SCullen Rhodes     return rewriter.notifyMatchFailure(
4379225349SCullen Rhodes         matmulOp, "only matmul ops with tensors are supported");
4479225349SCullen Rhodes 
4579225349SCullen Rhodes   Location loc = matmulOp.getLoc();
4679225349SCullen Rhodes   Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
4779225349SCullen Rhodes   auto type = cast<ShapedType>(input.getType());
4879225349SCullen Rhodes 
4979225349SCullen Rhodes   SmallVector<Value> dynamicDims;
5079225349SCullen Rhodes   if (type.isDynamicDim(1))
5179225349SCullen Rhodes     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
5279225349SCullen Rhodes   if (type.isDynamicDim(0))
5379225349SCullen Rhodes     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
5479225349SCullen Rhodes 
5579225349SCullen Rhodes   ArrayRef<int64_t> shape = type.getShape();
5679225349SCullen Rhodes   Value empty = rewriter.create<tensor::EmptyOp>(
5779225349SCullen Rhodes       loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
5879225349SCullen Rhodes       dynamicDims);
5979225349SCullen Rhodes   auto transposeOp = rewriter.create<linalg::TransposeOp>(
6079225349SCullen Rhodes       loc, input, empty, ArrayRef<int64_t>{1, 0});
61be1c72d2SCullen Rhodes   Operation *newMatmulOp;
6279225349SCullen Rhodes   if (transposeLHS) {
63be1c72d2SCullen Rhodes     newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
64be1c72d2SCullen Rhodes         loc, matmulOp.getResultTypes(),
6579225349SCullen Rhodes         ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
6679225349SCullen Rhodes         matmulOp.getOutputs());
6779225349SCullen Rhodes   } else {
68be1c72d2SCullen Rhodes     newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
69be1c72d2SCullen Rhodes         loc, matmulOp.getResultTypes(),
7079225349SCullen Rhodes         ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
7179225349SCullen Rhodes         matmulOp.getOutputs());
7279225349SCullen Rhodes   }
73be1c72d2SCullen Rhodes   rewriter.replaceOp(matmulOp, newMatmulOp);
74be1c72d2SCullen Rhodes   return newMatmulOp;
7579225349SCullen Rhodes }
7679225349SCullen Rhodes 
7779225349SCullen Rhodes /// Pattern to replace
7879225349SCullen Rhodes ///
7979225349SCullen Rhodes ///   linalg.batch_matmul(a, b)
8079225349SCullen Rhodes ///
8179225349SCullen Rhodes /// with
8279225349SCullen Rhodes ///
8379225349SCullen Rhodes ///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
8479225349SCullen Rhodes ///
8579225349SCullen Rhodes /// Only the non-batch dimensions are transposed. By default the LHS is
8679225349SCullen Rhodes /// transposed. Set `transposeLHS=false` to transpose RHS instead.
87be1c72d2SCullen Rhodes FailureOr<Operation *>
88be1c72d2SCullen Rhodes mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
89be1c72d2SCullen Rhodes                                    linalg::BatchMatmulOp batchMatmulOp,
90be1c72d2SCullen Rhodes                                    bool transposeLHS) {
9179225349SCullen Rhodes   if (!bufferization::hasTensorSemantics(batchMatmulOp))
9279225349SCullen Rhodes     return rewriter.notifyMatchFailure(
9379225349SCullen Rhodes         batchMatmulOp, "only matmul ops with tensors are supported");
9479225349SCullen Rhodes 
9579225349SCullen Rhodes   Location loc = batchMatmulOp.getLoc();
9679225349SCullen Rhodes   Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
9779225349SCullen Rhodes   auto type = cast<ShapedType>(input.getType());
9879225349SCullen Rhodes 
9979225349SCullen Rhodes   SmallVector<Value> dynamicDims;
10079225349SCullen Rhodes   if (type.isDynamicDim(0))
10179225349SCullen Rhodes     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
10279225349SCullen Rhodes   if (type.isDynamicDim(2))
10379225349SCullen Rhodes     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
10479225349SCullen Rhodes   if (type.isDynamicDim(1))
10579225349SCullen Rhodes     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
10679225349SCullen Rhodes 
10779225349SCullen Rhodes   ArrayRef<int64_t> shape = type.getShape();
10879225349SCullen Rhodes   Value empty = rewriter.create<tensor::EmptyOp>(
10979225349SCullen Rhodes       loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
11079225349SCullen Rhodes       type.getElementType(), dynamicDims);
11179225349SCullen Rhodes   auto transposeOp = rewriter.create<linalg::TransposeOp>(
11279225349SCullen Rhodes       loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
113be1c72d2SCullen Rhodes   Operation *newMatmulOp;
11479225349SCullen Rhodes   if (transposeLHS) {
115be1c72d2SCullen Rhodes     newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
116be1c72d2SCullen Rhodes         loc, batchMatmulOp.getResultTypes(),
11779225349SCullen Rhodes         ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
11879225349SCullen Rhodes         batchMatmulOp.getOutputs());
11979225349SCullen Rhodes   } else {
120be1c72d2SCullen Rhodes     newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
121be1c72d2SCullen Rhodes         loc, batchMatmulOp.getResultTypes(),
12279225349SCullen Rhodes         ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
12379225349SCullen Rhodes         batchMatmulOp.getOutputs());
12479225349SCullen Rhodes   }
125be1c72d2SCullen Rhodes   rewriter.replaceOp(batchMatmulOp, newMatmulOp);
126be1c72d2SCullen Rhodes   return newMatmulOp;
127be1c72d2SCullen Rhodes }
12879225349SCullen Rhodes 
129be1c72d2SCullen Rhodes namespace {
130be1c72d2SCullen Rhodes struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
131be1c72d2SCullen Rhodes   TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
132be1c72d2SCullen Rhodes       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
133be1c72d2SCullen Rhodes 
134be1c72d2SCullen Rhodes   LogicalResult matchAndRewrite(linalg::MatmulOp op,
135be1c72d2SCullen Rhodes                                 PatternRewriter &rewriter) const override {
136be1c72d2SCullen Rhodes     if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
137be1c72d2SCullen Rhodes       return failure();
138be1c72d2SCullen Rhodes     }
139be1c72d2SCullen Rhodes     return success();
140be1c72d2SCullen Rhodes   }
141be1c72d2SCullen Rhodes 
142be1c72d2SCullen Rhodes private:
143be1c72d2SCullen Rhodes   bool transposeLHS;
144be1c72d2SCullen Rhodes };
145be1c72d2SCullen Rhodes 
146be1c72d2SCullen Rhodes struct TransposeBatchMatmul final
147be1c72d2SCullen Rhodes     : public OpRewritePattern<linalg::BatchMatmulOp> {
148be1c72d2SCullen Rhodes   TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
149be1c72d2SCullen Rhodes       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
150be1c72d2SCullen Rhodes 
151be1c72d2SCullen Rhodes   LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
152be1c72d2SCullen Rhodes                                 PatternRewriter &rewriter) const override {
153be1c72d2SCullen Rhodes     if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
154be1c72d2SCullen Rhodes       return failure();
155be1c72d2SCullen Rhodes     }
15679225349SCullen Rhodes     return success();
15779225349SCullen Rhodes   }
15879225349SCullen Rhodes 
15979225349SCullen Rhodes private:
16079225349SCullen Rhodes   bool transposeLHS;
16179225349SCullen Rhodes };
16279225349SCullen Rhodes } // namespace
16379225349SCullen Rhodes 
16479225349SCullen Rhodes void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
16579225349SCullen Rhodes                                                    bool transposeLHS) {
16679225349SCullen Rhodes   patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
16779225349SCullen Rhodes                                                       transposeLHS);
16879225349SCullen Rhodes }
169