xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (revision 3ad0148020ca91cc288bffd8ad36e25f7555a3bb)
1 //===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This is intended to be a simple high-level (target-agnostic) matmul
9 // transposition transformation.
10 //===----------------------------------------------------------------------===//
11 
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 #define DEBUG_TYPE "linalg-transpose-matmul"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
21 /// Pattern to replace
22 ///
23 ///   linalg.matmul(a, b)
24 ///
25 /// with
26 ///
27 ///   linalg.matmul_transpose_a(linalg.transpose(a), b)
28 ///
29 /// By default the LHS is transposed. Set `transposeLHS=false` to
30 /// transpose RHS instead.
31 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
32                                                      linalg::MatmulOp matmulOp,
33                                                      bool transposeLHS) {
34   // Check to not let go the matmul with extended semantic, through this
35   // transform.
36   if (matmulOp.hasUserDefinedMaps()) {
37     return rewriter.notifyMatchFailure(
38         matmulOp, "only matmul ops with non-extended semantics are supported");
39   }
40 
41   if (!bufferization::hasTensorSemantics(matmulOp))
42     return rewriter.notifyMatchFailure(
43         matmulOp, "only matmul ops with tensors are supported");
44 
45   Location loc = matmulOp.getLoc();
46   Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
47   auto type = cast<ShapedType>(input.getType());
48 
49   SmallVector<Value> dynamicDims;
50   if (type.isDynamicDim(1))
51     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
52   if (type.isDynamicDim(0))
53     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
54 
55   ArrayRef<int64_t> shape = type.getShape();
56   Value empty = rewriter.create<tensor::EmptyOp>(
57       loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
58       dynamicDims);
59   auto transposeOp = rewriter.create<linalg::TransposeOp>(
60       loc, input, empty, ArrayRef<int64_t>{1, 0});
61   Operation *newMatmulOp;
62   if (transposeLHS) {
63     newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
64         loc, matmulOp.getResultTypes(),
65         ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
66         matmulOp.getOutputs());
67   } else {
68     newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
69         loc, matmulOp.getResultTypes(),
70         ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
71         matmulOp.getOutputs());
72   }
73   rewriter.replaceOp(matmulOp, newMatmulOp);
74   return newMatmulOp;
75 }
76 
77 /// Pattern to replace
78 ///
79 ///   linalg.batch_matmul(a, b)
80 ///
81 /// with
82 ///
83 ///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
84 ///
85 /// Only the non-batch dimensions are transposed. By default the LHS is
86 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
87 FailureOr<Operation *>
88 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
89                                    linalg::BatchMatmulOp batchMatmulOp,
90                                    bool transposeLHS) {
91   if (!bufferization::hasTensorSemantics(batchMatmulOp))
92     return rewriter.notifyMatchFailure(
93         batchMatmulOp, "only matmul ops with tensors are supported");
94 
95   Location loc = batchMatmulOp.getLoc();
96   Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
97   auto type = cast<ShapedType>(input.getType());
98 
99   SmallVector<Value> dynamicDims;
100   if (type.isDynamicDim(0))
101     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
102   if (type.isDynamicDim(2))
103     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
104   if (type.isDynamicDim(1))
105     dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
106 
107   ArrayRef<int64_t> shape = type.getShape();
108   Value empty = rewriter.create<tensor::EmptyOp>(
109       loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
110       type.getElementType(), dynamicDims);
111   auto transposeOp = rewriter.create<linalg::TransposeOp>(
112       loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
113   Operation *newMatmulOp;
114   if (transposeLHS) {
115     newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
116         loc, batchMatmulOp.getResultTypes(),
117         ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
118         batchMatmulOp.getOutputs());
119   } else {
120     newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
121         loc, batchMatmulOp.getResultTypes(),
122         ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
123         batchMatmulOp.getOutputs());
124   }
125   rewriter.replaceOp(batchMatmulOp, newMatmulOp);
126   return newMatmulOp;
127 }
128 
129 namespace {
130 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
131   TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
132       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
133 
134   LogicalResult matchAndRewrite(linalg::MatmulOp op,
135                                 PatternRewriter &rewriter) const override {
136     if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
137       return failure();
138     }
139     return success();
140   }
141 
142 private:
143   bool transposeLHS;
144 };
145 
146 struct TransposeBatchMatmul final
147     : public OpRewritePattern<linalg::BatchMatmulOp> {
148   TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
149       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
150 
151   LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
152                                 PatternRewriter &rewriter) const override {
153     if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
154       return failure();
155     }
156     return success();
157   }
158 
159 private:
160   bool transposeLHS;
161 };
162 } // namespace
163 
164 void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
165                                                    bool transposeLHS) {
166   patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
167                                                       transposeLHS);
168 }
169