1 //===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This is intended to be a simple high-level (target-agnostic) matmul 9 // transposition transformation. 10 //===----------------------------------------------------------------------===// 11 12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 #define DEBUG_TYPE "linalg-transpose-matmul" 17 18 using namespace mlir; 19 using namespace mlir::linalg; 20 21 /// Pattern to replace 22 /// 23 /// linalg.matmul(a, b) 24 /// 25 /// with 26 /// 27 /// linalg.matmul_transpose_a(linalg.transpose(a), b) 28 /// 29 /// By default the LHS is transposed. Set `transposeLHS=false` to 30 /// transpose RHS instead. 31 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter, 32 linalg::MatmulOp matmulOp, 33 bool transposeLHS) { 34 // Check to not let go the matmul with extended semantic, through this 35 // transform. 36 if (matmulOp.hasUserDefinedMaps()) { 37 return rewriter.notifyMatchFailure( 38 matmulOp, "only matmul ops with non-extended semantics are supported"); 39 } 40 41 if (!bufferization::hasTensorSemantics(matmulOp)) 42 return rewriter.notifyMatchFailure( 43 matmulOp, "only matmul ops with tensors are supported"); 44 45 Location loc = matmulOp.getLoc(); 46 Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1]; 47 auto type = cast<ShapedType>(input.getType()); 48 49 SmallVector<Value> dynamicDims; 50 if (type.isDynamicDim(1)) 51 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1)); 52 if (type.isDynamicDim(0)) 53 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0)); 54 55 ArrayRef<int64_t> shape = type.getShape(); 56 Value empty = rewriter.create<tensor::EmptyOp>( 57 loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(), 58 dynamicDims); 59 auto transposeOp = rewriter.create<linalg::TransposeOp>( 60 loc, input, empty, ArrayRef<int64_t>{1, 0}); 61 Operation *newMatmulOp; 62 if (transposeLHS) { 63 newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>( 64 loc, matmulOp.getResultTypes(), 65 ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]}, 66 matmulOp.getOutputs()); 67 } else { 68 newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>( 69 loc, matmulOp.getResultTypes(), 70 ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)}, 71 matmulOp.getOutputs()); 72 } 73 rewriter.replaceOp(matmulOp, newMatmulOp); 74 return newMatmulOp; 75 } 76 77 /// Pattern to replace 78 /// 79 /// linalg.batch_matmul(a, b) 80 /// 81 /// with 82 /// 83 /// linalg.batch_matmul_transpose_a(linalg.transpose(a), b) 84 /// 85 /// Only the non-batch dimensions are transposed. By default the LHS is 86 /// transposed. Set `transposeLHS=false` to transpose RHS instead. 87 FailureOr<Operation *> 88 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, 89 linalg::BatchMatmulOp batchMatmulOp, 90 bool transposeLHS) { 91 if (!bufferization::hasTensorSemantics(batchMatmulOp)) 92 return rewriter.notifyMatchFailure( 93 batchMatmulOp, "only matmul ops with tensors are supported"); 94 95 Location loc = batchMatmulOp.getLoc(); 96 Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1]; 97 auto type = cast<ShapedType>(input.getType()); 98 99 SmallVector<Value> dynamicDims; 100 if (type.isDynamicDim(0)) 101 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0)); 102 if (type.isDynamicDim(2)) 103 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2)); 104 if (type.isDynamicDim(1)) 105 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1)); 106 107 ArrayRef<int64_t> shape = type.getShape(); 108 Value empty = rewriter.create<tensor::EmptyOp>( 109 loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]}, 110 type.getElementType(), dynamicDims); 111 auto transposeOp = rewriter.create<linalg::TransposeOp>( 112 loc, input, empty, ArrayRef<int64_t>{0, 2, 1}); 113 Operation *newMatmulOp; 114 if (transposeLHS) { 115 newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>( 116 loc, batchMatmulOp.getResultTypes(), 117 ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]}, 118 batchMatmulOp.getOutputs()); 119 } else { 120 newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>( 121 loc, batchMatmulOp.getResultTypes(), 122 ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)}, 123 batchMatmulOp.getOutputs()); 124 } 125 rewriter.replaceOp(batchMatmulOp, newMatmulOp); 126 return newMatmulOp; 127 } 128 129 namespace { 130 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> { 131 TransposeMatmul(MLIRContext *ctx, bool transposeLHS) 132 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {} 133 134 LogicalResult matchAndRewrite(linalg::MatmulOp op, 135 PatternRewriter &rewriter) const override { 136 if (failed(transposeMatmul(rewriter, op, transposeLHS))) { 137 return failure(); 138 } 139 return success(); 140 } 141 142 private: 143 bool transposeLHS; 144 }; 145 146 struct TransposeBatchMatmul final 147 : public OpRewritePattern<linalg::BatchMatmulOp> { 148 TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS) 149 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {} 150 151 LogicalResult matchAndRewrite(linalg::BatchMatmulOp op, 152 PatternRewriter &rewriter) const override { 153 if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) { 154 return failure(); 155 } 156 return success(); 157 } 158 159 private: 160 bool transposeLHS; 161 }; 162 } // namespace 163 164 void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns, 165 bool transposeLHS) { 166 patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(), 167 transposeLHS); 168 } 169