//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is intended to be a simple high-level (target-agnostic) matmul
// transposition transformation.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "linalg-transpose-matmul"

using namespace mlir;
using namespace mlir::linalg;

/// Pattern to replace
///
///   linalg.matmul(a, b)
///
/// with
///
///   linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to transpose
/// the RHS instead.
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                     linalg::MatmulOp matmulOp,
                                                     bool transposeLHS) {
  // Bail out on matmul ops with extended (user-defined indexing map)
  // semantics; this transform only supports the default matmul form.
  if (matmulOp.hasUserDefinedMaps()) {
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with non-extended semantics are supported");
  }

  if (!bufferization::hasTensorSemantics(matmulOp))
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with tensors are supported");

  Location loc = matmulOp.getLoc();
  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
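  // Materialize the transpose of the chosen operand: build a `tensor.empty`
  // init with the two dimensions swapped (forwarding any dynamic sizes in the
  // transposed order), then a `linalg.transpose` with permutation [1, 0].
  // Illustratively, a `tensor<?x8xf32>` LHS is transposed into a
  // `tensor<8x?xf32>` that then feeds `linalg.matmul_transpose_a`.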
  auto type = cast<ShapedType>(input.getType());

  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(1))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
  if (type.isDynamicDim(0))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

  ArrayRef<int64_t> shape = type.getShape();
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
      dynamicDims);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, ArrayRef<int64_t>{1, 0});
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
        matmulOp.getOutputs());
  } else {
    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
        matmulOp.getOutputs());
  }
  rewriter.replaceOp(matmulOp, newMatmulOp);
  return newMatmulOp;
}

/// Pattern to replace
///
///   linalg.batch_matmul(a, b)
///
/// with
///
///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose the RHS instead.
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                   linalg::BatchMatmulOp batchMatmulOp,
                                   bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(batchMatmulOp))
    return rewriter.notifyMatchFailure(
        batchMatmulOp, "only batch_matmul ops with tensors are supported");

  Location loc = batchMatmulOp.getLoc();
  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
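  // Same recipe as the non-batched case, but only the two trailing (non-batch)
  // dimensions are swapped: the init tensor gets shape {batch, cols, rows} and
  // the `linalg.transpose` uses permutation [0, 2, 1], so e.g. a
  // `tensor<4x?x16xf32>` operand becomes a `tensor<4x16x?xf32>`.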
  auto type = cast<ShapedType>(input.getType());

  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(0))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
  if (type.isDynamicDim(2))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
  if (type.isDynamicDim(1))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));

  ArrayRef<int64_t> shape = type.getShape();
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
      type.getElementType(), dynamicDims);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
        batchMatmulOp.getOutputs());
  } else {
    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
        batchMatmulOp.getOutputs());
  }
  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
  return newMatmulOp;
}

namespace {
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};

struct TransposeBatchMatmul final
    : public OpRewritePattern<linalg::BatchMatmulOp> {
  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};
} // namespace
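
// These patterns are meant to be driven greedily. A minimal sketch of a
// driver (assuming a pass with an `Operation *op` and `MLIRContext *ctx` in
// scope; the surrounding pass boilerplate is not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateTransposeMatmulPatterns(patterns, /*transposeLHS=*/true);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();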
void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                                   bool transposeLHS) {
  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
                                                      transposeLHS);
}