xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (revision 956c0707d9098499a2682297b71f46b0a562eed9)
13bcaf2ebSGeorgios Pinitas //===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
254eec7caSRob Suderman //
354eec7caSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
454eec7caSRob Suderman // See https://llvm.org/LICENSE.txt for license information.
554eec7caSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
654eec7caSRob Suderman //
754eec7caSRob Suderman //===----------------------------------------------------------------------===//
854eec7caSRob Suderman //
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided (all strides 1) TransposeConv2D to a Conv2D by
//     reversing the weights along the spatial axes and adjusting the padding.
// (2) Convert a strided TransposeConv2D to a Conv2D by padding, reshaping,
//     transposing and reversing the weights, and by reshaping, transposing,
//     slicing and padding the input/output tensors.
1554eec7caSRob Suderman //
1654eec7caSRob Suderman //===----------------------------------------------------------------------===//
1754eec7caSRob Suderman 
18dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/IR/TosaOps.h"
1954eec7caSRob Suderman #include "mlir/Dialect/Tosa/Transforms/Passes.h"
20e0537d1aSTai Ly #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
2154eec7caSRob Suderman #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
2254eec7caSRob Suderman #include "mlir/Pass/Pass.h"
2354eec7caSRob Suderman 
2454eec7caSRob Suderman using namespace mlir;
2554eec7caSRob Suderman using namespace mlir::tosa;
2654eec7caSRob Suderman 
2754eec7caSRob Suderman namespace {
2854eec7caSRob Suderman 
29b7f4335dSEric Kunze class TransposeConvNonStridedConverter
3054eec7caSRob Suderman     : public OpRewritePattern<tosa::TransposeConv2DOp> {
3154eec7caSRob Suderman public:
3254eec7caSRob Suderman   using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
3354eec7caSRob Suderman   LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
3454eec7caSRob Suderman                                 PatternRewriter &rewriter) const final {
3554eec7caSRob Suderman     Location loc = op->getLoc();
3654eec7caSRob Suderman     Value input = op->getOperand(0);
3754eec7caSRob Suderman     Value weight = op->getOperand(1);
3854eec7caSRob Suderman     Value bias = op->getOperand(2);
3954eec7caSRob Suderman 
405550c821STres Popp     ShapedType inputTy = cast<ShapedType>(input.getType());
415550c821STres Popp     ShapedType weightTy = cast<ShapedType>(weight.getType());
425550c821STres Popp     ShapedType biasTy = cast<ShapedType>(bias.getType());
435550c821STres Popp     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
4454eec7caSRob Suderman 
4511030c7dSAlexander Shaposhnikov     llvm::ArrayRef<int64_t> stride = op.getStride();
4611030c7dSAlexander Shaposhnikov     llvm::ArrayRef<int64_t> pad = op.getOutPad();
4754eec7caSRob Suderman 
4854eec7caSRob Suderman     // If striding is all 1 we can modify padding and reverse the kernel along
4954eec7caSRob Suderman     // the x/y direction to make it a regular convolution. This is much simpler
5054eec7caSRob Suderman     // then handling striding....
5154eec7caSRob Suderman     if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
5254eec7caSRob Suderman       return failure();
5354eec7caSRob Suderman 
5454eec7caSRob Suderman     if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
5554eec7caSRob Suderman         !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
5654eec7caSRob Suderman       return failure();
5754eec7caSRob Suderman 
58b7f4335dSEric Kunze     int64_t kernelHeight = weightTy.getDimSize(1);
59b7f4335dSEric Kunze     int64_t kernelWidth = weightTy.getDimSize(2);
6054eec7caSRob Suderman 
6154eec7caSRob Suderman     llvm::SmallVector<int64_t> convPad(4, 0);
62fcbf3fafSRob Suderman     convPad[0] = kernelHeight - 1 + pad[0];
63fcbf3fafSRob Suderman     convPad[1] = kernelHeight - 1 + pad[1];
64fcbf3fafSRob Suderman     convPad[2] = kernelWidth - 1 + pad[2];
65fcbf3fafSRob Suderman     convPad[3] = kernelWidth - 1 + pad[3];
6654eec7caSRob Suderman 
6754eec7caSRob Suderman     auto reverse1 = rewriter.create<tosa::ReverseOp>(
68fc0d4d53STai Ly         loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
6954eec7caSRob Suderman     auto reverse2 = rewriter.create<tosa::ReverseOp>(
70fc0d4d53STai Ly         loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));
7154eec7caSRob Suderman 
7254eec7caSRob Suderman     Value conv2d;
7313448db0SJacques Pienaar     if (op.getQuantizationInfo()) {
7454eec7caSRob Suderman       conv2d = rewriter.create<tosa::Conv2DOp>(
7554eec7caSRob Suderman           loc, resultTy, input, reverse2, bias,
7611030c7dSAlexander Shaposhnikov           rewriter.getDenseI64ArrayAttr(convPad),
7711030c7dSAlexander Shaposhnikov           rewriter.getDenseI64ArrayAttr(stride),
78360a03c9SJack Frankland           rewriter.getDenseI64ArrayAttr({1, 1}),
79360a03c9SJack Frankland           /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
8054eec7caSRob Suderman     } else {
8154eec7caSRob Suderman       conv2d = rewriter.create<tosa::Conv2DOp>(
8254eec7caSRob Suderman           loc, resultTy, input, reverse2, bias,
8311030c7dSAlexander Shaposhnikov           rewriter.getDenseI64ArrayAttr(convPad),
8411030c7dSAlexander Shaposhnikov           rewriter.getDenseI64ArrayAttr(stride),
85360a03c9SJack Frankland           rewriter.getDenseI64ArrayAttr({1, 1}),
86360a03c9SJack Frankland           /* acc_type = */ op.getAccTypeAttr());
8754eec7caSRob Suderman     }
8854eec7caSRob Suderman 
8954eec7caSRob Suderman     rewriter.replaceOp(op, conv2d);
9054eec7caSRob Suderman     return success();
9154eec7caSRob Suderman   }
9254eec7caSRob Suderman };
9354eec7caSRob Suderman 
9454eec7caSRob Suderman class TransposeConvStridedConverter
9554eec7caSRob Suderman     : public OpRewritePattern<tosa::TransposeConv2DOp> {
9654eec7caSRob Suderman public:
9754eec7caSRob Suderman   using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
9854eec7caSRob Suderman   LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
9954eec7caSRob Suderman                                 PatternRewriter &rewriter) const final {
10054eec7caSRob Suderman     Location loc = op->getLoc();
10154eec7caSRob Suderman     Value input = op->getOperand(0);
10254eec7caSRob Suderman     Value weight = op->getOperand(1);
10354eec7caSRob Suderman     Value bias = op->getOperand(2);
10454eec7caSRob Suderman 
1055550c821STres Popp     ShapedType inputTy = cast<ShapedType>(input.getType());
1065550c821STres Popp     ShapedType weightTy = cast<ShapedType>(weight.getType());
1075550c821STres Popp     ShapedType biasTy = cast<ShapedType>(bias.getType());
1085550c821STres Popp     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
10954eec7caSRob Suderman 
11054eec7caSRob Suderman     Type inputETy = inputTy.getElementType();
11154eec7caSRob Suderman     Type weightETy = weightTy.getElementType();
11254eec7caSRob Suderman     Type biasETy = biasTy.getElementType();
11354eec7caSRob Suderman     Type resultETy = resultTy.getElementType();
11454eec7caSRob Suderman 
11511030c7dSAlexander Shaposhnikov     llvm::ArrayRef<int64_t> pad = op.getOutPad();
11611030c7dSAlexander Shaposhnikov     llvm::ArrayRef<int64_t> stride = op.getStride();
11754eec7caSRob Suderman 
11854eec7caSRob Suderman     // If striding is all 1 we can modify padding and reverse the kernel along
11954eec7caSRob Suderman     // the x/y direction to make it a regular convolution. This is much simpler
12054eec7caSRob Suderman     // then handling striding....
12154eec7caSRob Suderman 
12254eec7caSRob Suderman     // If strides are all 1 we dont need to use this one.
12354eec7caSRob Suderman     if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
124fcbf3fafSRob Suderman       return rewriter.notifyMatchFailure(op, "non-one stride found.");
12554eec7caSRob Suderman 
12654eec7caSRob Suderman     if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
12754eec7caSRob Suderman         !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
12854eec7caSRob Suderman       return failure();
12954eec7caSRob Suderman 
13054eec7caSRob Suderman     int64_t batch = inputTy.getDimSize(0);
13154eec7caSRob Suderman 
13254eec7caSRob Suderman     int64_t outputChannels = weightTy.getDimSize(0);
13354eec7caSRob Suderman     int64_t weightHeight = weightTy.getDimSize(1);
13454eec7caSRob Suderman     int64_t weightWidth = weightTy.getDimSize(2);
13554eec7caSRob Suderman     int64_t inputChannels = weightTy.getDimSize(3);
13654eec7caSRob Suderman 
13754eec7caSRob Suderman     // Pad the weight so that it is modulo of the striding.
1387e622b61SJerry-Ge     llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
13954eec7caSRob Suderman     weightPadding[3] =
140a98a6e95Sluolent         (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
14154eec7caSRob Suderman     weightPadding[5] =
1427e622b61SJerry-Ge         weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
1437e622b61SJerry-Ge 
1447e622b61SJerry-Ge     Value weightPaddingVal =
1457e622b61SJerry-Ge         getTosaConstShape(rewriter, op->getLoc(), weightPadding);
14654eec7caSRob Suderman 
14713448db0SJacques Pienaar     if (op.getQuantizationInfo().has_value()) {
14813448db0SJacques Pienaar       auto quantInfo = op.getQuantizationInfo().value();
149c8834527STai Ly       weight = CreateOpAndInferShape<tosa::PadOp>(
15054eec7caSRob Suderman           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
15154eec7caSRob Suderman           weightPaddingVal, nullptr,
152baca1c1aSMogball           rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
15354eec7caSRob Suderman 
15454eec7caSRob Suderman     } else {
155c8834527STai Ly       weight = CreateOpAndInferShape<tosa::PadOp>(
156c8834527STai Ly           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
157c8834527STai Ly           weightPaddingVal);
15854eec7caSRob Suderman     }
15954eec7caSRob Suderman 
1605550c821STres Popp     weightTy = cast<ShapedType>(weight.getType());
16154eec7caSRob Suderman     weightHeight = weightTy.getDimSize(1);
16254eec7caSRob Suderman     weightWidth = weightTy.getDimSize(2);
16354eec7caSRob Suderman 
16454eec7caSRob Suderman     // Split out the width / height by the stride dimensions.
16554eec7caSRob Suderman     llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
16654eec7caSRob Suderman         outputChannels, weightHeight / stride[0],
16754eec7caSRob Suderman         stride[0],      weightWidth / stride[1],
16854eec7caSRob Suderman         stride[1],      inputChannels};
169c8834527STai Ly     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
17054eec7caSRob Suderman         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
1719e1a3441SAlexander Shaposhnikov         rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
17254eec7caSRob Suderman 
17354eec7caSRob Suderman     // Transpose the factored-out stride to the output channels.
17454eec7caSRob Suderman     Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
17554eec7caSRob Suderman         loc, RankedTensorType::get({6}, rewriter.getI32Type()),
17654eec7caSRob Suderman         rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
17754eec7caSRob Suderman 
178c8834527STai Ly     weight = CreateOpAndInferShape<tosa::TransposeOp>(
17954eec7caSRob Suderman         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
18054eec7caSRob Suderman         transposeWeightVal);
18154eec7caSRob Suderman 
18254eec7caSRob Suderman     // Collapse the strides and output channels into a single dimension.
18354eec7caSRob Suderman     llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
18454eec7caSRob Suderman         outputChannels * stride[0] * stride[1], weightHeight / stride[0],
18554eec7caSRob Suderman         weightWidth / stride[1], inputChannels};
186c8834527STai Ly     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
18754eec7caSRob Suderman         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
1889e1a3441SAlexander Shaposhnikov         rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
1895550c821STres Popp     ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
19054eec7caSRob Suderman 
191c8834527STai Ly     weight = CreateOpAndInferShape<tosa::ReverseOp>(
19254eec7caSRob Suderman         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
193fc0d4d53STai Ly         /* axis = */ rewriter.getI32IntegerAttr(1));
194c8834527STai Ly     weight = CreateOpAndInferShape<tosa::ReverseOp>(
19554eec7caSRob Suderman         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
196fc0d4d53STai Ly         /* axis = */ rewriter.getI32IntegerAttr(2));
19754eec7caSRob Suderman 
19854eec7caSRob Suderman     // We need to pad the input far enough that we can pull all values.
1997e622b61SJerry-Ge     llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
20054eec7caSRob Suderman     inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
20154eec7caSRob Suderman     inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
20254eec7caSRob Suderman     inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
20354eec7caSRob Suderman     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
20454eec7caSRob Suderman 
2057e622b61SJerry-Ge     Value inputPaddingVal =
2067e622b61SJerry-Ge         getTosaConstShape(rewriter, op->getLoc(), inputPadding);
20754eec7caSRob Suderman 
20813448db0SJacques Pienaar     if (op.getQuantizationInfo().has_value()) {
20913448db0SJacques Pienaar       auto quantInfo = op.getQuantizationInfo().value();
210c8834527STai Ly       input = CreateOpAndInferShape<tosa::PadOp>(
21154eec7caSRob Suderman           rewriter, loc, UnrankedTensorType::get(inputETy), input,
21254eec7caSRob Suderman           inputPaddingVal, nullptr,
213baca1c1aSMogball           rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
21454eec7caSRob Suderman     } else {
215c8834527STai Ly       input = CreateOpAndInferShape<tosa::PadOp>(
216c8834527STai Ly           rewriter, loc, UnrankedTensorType::get(inputETy), input,
217c8834527STai Ly           inputPaddingVal);
21854eec7caSRob Suderman     }
21954eec7caSRob Suderman 
22054eec7caSRob Suderman     // We use a zero bias as we need to broadcast the bias.
22154eec7caSRob Suderman     auto zeroBias = rewriter.create<tosa::ConstOp>(
22254eec7caSRob Suderman         loc,
22354eec7caSRob Suderman         RankedTensorType::get({outputChannels * stride[0] * stride[1]},
22454eec7caSRob Suderman                               biasETy),
22554eec7caSRob Suderman         DenseElementsAttr::get(
22654eec7caSRob Suderman             RankedTensorType::get({outputChannels * stride[0] * stride[1]},
22754eec7caSRob Suderman                                   biasETy),
22854eec7caSRob Suderman             rewriter.getZeroAttr(biasETy)));
22954eec7caSRob Suderman 
23054eec7caSRob Suderman     // Perform the convolution using the zero bias.
23154eec7caSRob Suderman     Value conv2d;
23213448db0SJacques Pienaar     if (op.getQuantizationInfo()) {
233c8834527STai Ly       conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
23454eec7caSRob Suderman                    rewriter, loc, UnrankedTensorType::get(resultETy), input,
23554eec7caSRob Suderman                    weight, zeroBias,
23611030c7dSAlexander Shaposhnikov                    /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
23711030c7dSAlexander Shaposhnikov                    /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
23811030c7dSAlexander Shaposhnikov                    /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
239360a03c9SJack Frankland                    /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
24054eec7caSRob Suderman                    .getResult();
24154eec7caSRob Suderman     } else {
242c8834527STai Ly       conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
24354eec7caSRob Suderman                    rewriter, loc, UnrankedTensorType::get(resultETy), input,
24454eec7caSRob Suderman                    weight, zeroBias,
24511030c7dSAlexander Shaposhnikov                    /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
24611030c7dSAlexander Shaposhnikov                    /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
247360a03c9SJack Frankland                    /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
248360a03c9SJack Frankland                    /* acc_type = */ op.getAccTypeAttr())
24954eec7caSRob Suderman                    .getResult();
25054eec7caSRob Suderman     }
25154eec7caSRob Suderman 
25254eec7caSRob Suderman     // Factor the resulting width / height.
2535550c821STres Popp     ShapedType convTy = cast<ShapedType>(conv2d.getType());
25454eec7caSRob Suderman     Type convETy = convTy.getElementType();
25554eec7caSRob Suderman 
25654eec7caSRob Suderman     int64_t convHeight = convTy.getDimSize(1);
25754eec7caSRob Suderman     int64_t convWidth = convTy.getDimSize(2);
25854eec7caSRob Suderman 
25954eec7caSRob Suderman     // Factor striding out of the convolution result.
26054eec7caSRob Suderman     llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
26154eec7caSRob Suderman         batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
262c8834527STai Ly     conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
26354eec7caSRob Suderman         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
2649e1a3441SAlexander Shaposhnikov         rewriter.getDenseI64ArrayAttr(convReshapeDims0));
26554eec7caSRob Suderman 
26654eec7caSRob Suderman     // Transpose the factored-out stride to the output channels.
26754eec7caSRob Suderman     Value transposeConvVal = rewriter.create<tosa::ConstOp>(
26854eec7caSRob Suderman         loc, RankedTensorType::get({6}, rewriter.getI32Type()),
26954eec7caSRob Suderman         rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
27054eec7caSRob Suderman 
271c8834527STai Ly     conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
27254eec7caSRob Suderman         rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
27354eec7caSRob Suderman         transposeConvVal);
27454eec7caSRob Suderman 
27554eec7caSRob Suderman     // Fuse striding behavior back into width / height.
27654eec7caSRob Suderman     llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
27754eec7caSRob Suderman         batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
278c8834527STai Ly     conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
27954eec7caSRob Suderman         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
2809e1a3441SAlexander Shaposhnikov         rewriter.getDenseI64ArrayAttr(convReshapeDims1));
28154eec7caSRob Suderman 
282fcbf3fafSRob Suderman     // Determine the amount to slice / pad from the result start.
283fcbf3fafSRob Suderman     int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
284fcbf3fafSRob Suderman     int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
285fcbf3fafSRob Suderman     int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
286fcbf3fafSRob Suderman     int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
287fcbf3fafSRob Suderman 
288fcbf3fafSRob Suderman     // Try to slice the targetted result size, cap to the convolutions width.
289fcbf3fafSRob Suderman     int64_t resultSliceHeight =
290fcbf3fafSRob Suderman         std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
291fcbf3fafSRob Suderman                           resultTy.getDimSize(1) - resultPadTop);
292fcbf3fafSRob Suderman     int64_t resultSliceWidth =
293fcbf3fafSRob Suderman         std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
294fcbf3fafSRob Suderman                           resultTy.getDimSize(2) - resultPadLeft);
295fcbf3fafSRob Suderman 
296fcbf3fafSRob Suderman     llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
297fcbf3fafSRob Suderman                                                 resultSliceLeft, 0};
298fcbf3fafSRob Suderman     llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
299fcbf3fafSRob Suderman                                             convReshapeDims1.end());
300fcbf3fafSRob Suderman     sliceSize[1] = resultSliceHeight;
301fcbf3fafSRob Suderman     sliceSize[2] = resultSliceWidth;
30254eec7caSRob Suderman 
303c8834527STai Ly     auto slice = CreateOpAndInferShape<tosa::SliceOp>(
30454eec7caSRob Suderman                      rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
305*956c0707SJerry-Ge                      getTosaConstShape(rewriter, loc, sliceBegin),
306*956c0707SJerry-Ge                      getTosaConstShape(rewriter, loc, sliceSize))
30754eec7caSRob Suderman                      .getResult();
30854eec7caSRob Suderman 
3097e622b61SJerry-Ge     llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
310fcbf3fafSRob Suderman     resultPadding[2] = resultPadTop;
311fcbf3fafSRob Suderman     resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
312fcbf3fafSRob Suderman     resultPadding[4] = resultPadLeft;
313fcbf3fafSRob Suderman     resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
31454eec7caSRob Suderman 
3157e622b61SJerry-Ge     Value resultPaddingVal =
3167e622b61SJerry-Ge         getTosaConstShape(rewriter, op->getLoc(), resultPadding);
317fcbf3fafSRob Suderman 
318c8834527STai Ly     Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
319fcbf3fafSRob Suderman         rewriter, loc, UnrankedTensorType::get(resultETy), slice,
320fcbf3fafSRob Suderman         resultPaddingVal);
321fcbf3fafSRob Suderman 
322e0537d1aSTai Ly     if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
323e0537d1aSTai Ly       return failure();
324e0537d1aSTai Ly     }
325e0537d1aSTai Ly 
326fcbf3fafSRob Suderman     rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
32754eec7caSRob Suderman     return success();
32854eec7caSRob Suderman   }
32954eec7caSRob Suderman };
33054eec7caSRob Suderman 
331be0a7e9fSMehdi Amini } // namespace
33254eec7caSRob Suderman 
333dfd07082SAaron DeBattista void mlir::tosa::populateTosaDecomposeTransposeConv(
334dfd07082SAaron DeBattista     MLIRContext *ctx, RewritePatternSet &patterns) {
335b7f4335dSEric Kunze   patterns.add<TransposeConvNonStridedConverter>(ctx);
336b4e0507cSTres Popp   patterns.add<TransposeConvStridedConverter>(ctx);
33754eec7caSRob Suderman }
338