13bcaf2ebSGeorgios Pinitas //===- TosaDecomposeTransposeConv.cpp -------------------------------------===// 254eec7caSRob Suderman // 354eec7caSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 454eec7caSRob Suderman // See https://llvm.org/LICENSE.txt for license information. 554eec7caSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 654eec7caSRob Suderman // 754eec7caSRob Suderman //===----------------------------------------------------------------------===// 854eec7caSRob Suderman // 9dfd07082SAaron DeBattista // Decompose TOSA TransposeConv operation to a series of TOSA Ops specifically 10dfd07082SAaron DeBattista // (1) Convert a Dilated TransposeConv2D to Conv2D including reversing/reshaping 11dfd07082SAaron DeBattista // etc.. of the weights (2) Convert a Strided TransposeConv2D to Conv2D 12dfd07082SAaron DeBattista // including transposing/reversing/reshaping etc.. 13dfd07082SAaron DeBattista // of the weights and input/output tenors and reversing/reshaping etc .. of 14dfd07082SAaron DeBattista // the weights 1554eec7caSRob Suderman // 1654eec7caSRob Suderman //===----------------------------------------------------------------------===// 1754eec7caSRob Suderman 18dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/IR/TosaOps.h" 1954eec7caSRob Suderman #include "mlir/Dialect/Tosa/Transforms/Passes.h" 20e0537d1aSTai Ly #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 2154eec7caSRob Suderman #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" 2254eec7caSRob Suderman #include "mlir/Pass/Pass.h" 2354eec7caSRob Suderman 2454eec7caSRob Suderman using namespace mlir; 2554eec7caSRob Suderman using namespace mlir::tosa; 2654eec7caSRob Suderman 2754eec7caSRob Suderman namespace { 2854eec7caSRob Suderman 29b7f4335dSEric Kunze class TransposeConvNonStridedConverter 3054eec7caSRob Suderman : public OpRewritePattern<tosa::TransposeConv2DOp> { 3154eec7caSRob Suderman public: 3254eec7caSRob Suderman using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern; 3354eec7caSRob Suderman LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op, 3454eec7caSRob Suderman PatternRewriter &rewriter) const final { 3554eec7caSRob Suderman Location loc = op->getLoc(); 3654eec7caSRob Suderman Value input = op->getOperand(0); 3754eec7caSRob Suderman Value weight = op->getOperand(1); 3854eec7caSRob Suderman Value bias = op->getOperand(2); 3954eec7caSRob Suderman 405550c821STres Popp ShapedType inputTy = cast<ShapedType>(input.getType()); 415550c821STres Popp ShapedType weightTy = cast<ShapedType>(weight.getType()); 425550c821STres Popp ShapedType biasTy = cast<ShapedType>(bias.getType()); 435550c821STres Popp ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType()); 4454eec7caSRob Suderman 4511030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> stride = op.getStride(); 4611030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> pad = op.getOutPad(); 4754eec7caSRob Suderman 4854eec7caSRob Suderman // If striding is all 1 we can modify padding and reverse the kernel along 4954eec7caSRob Suderman // the x/y direction to make it a regular convolution. This is much simpler 5054eec7caSRob Suderman // then handling striding.... 5154eec7caSRob Suderman if (llvm::any_of(stride, [](int64_t v) { return v != 1; })) 5254eec7caSRob Suderman return failure(); 5354eec7caSRob Suderman 5454eec7caSRob Suderman if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || 5554eec7caSRob Suderman !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) 5654eec7caSRob Suderman return failure(); 5754eec7caSRob Suderman 58b7f4335dSEric Kunze int64_t kernelHeight = weightTy.getDimSize(1); 59b7f4335dSEric Kunze int64_t kernelWidth = weightTy.getDimSize(2); 6054eec7caSRob Suderman 6154eec7caSRob Suderman llvm::SmallVector<int64_t> convPad(4, 0); 62fcbf3fafSRob Suderman convPad[0] = kernelHeight - 1 + pad[0]; 63fcbf3fafSRob Suderman convPad[1] = kernelHeight - 1 + pad[1]; 64fcbf3fafSRob Suderman convPad[2] = kernelWidth - 1 + pad[2]; 65fcbf3fafSRob Suderman convPad[3] = kernelWidth - 1 + pad[3]; 6654eec7caSRob Suderman 6754eec7caSRob Suderman auto reverse1 = rewriter.create<tosa::ReverseOp>( 68fc0d4d53STai Ly loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1)); 6954eec7caSRob Suderman auto reverse2 = rewriter.create<tosa::ReverseOp>( 70fc0d4d53STai Ly loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2)); 7154eec7caSRob Suderman 7254eec7caSRob Suderman Value conv2d; 7313448db0SJacques Pienaar if (op.getQuantizationInfo()) { 7454eec7caSRob Suderman conv2d = rewriter.create<tosa::Conv2DOp>( 7554eec7caSRob Suderman loc, resultTy, input, reverse2, bias, 7611030c7dSAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(convPad), 7711030c7dSAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(stride), 78360a03c9SJack Frankland rewriter.getDenseI64ArrayAttr({1, 1}), 79360a03c9SJack Frankland /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()); 8054eec7caSRob Suderman } else { 8154eec7caSRob Suderman conv2d = rewriter.create<tosa::Conv2DOp>( 8254eec7caSRob Suderman loc, resultTy, input, reverse2, bias, 8311030c7dSAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(convPad), 8411030c7dSAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(stride), 85360a03c9SJack Frankland rewriter.getDenseI64ArrayAttr({1, 1}), 86360a03c9SJack Frankland /* acc_type = */ op.getAccTypeAttr()); 8754eec7caSRob Suderman } 8854eec7caSRob Suderman 8954eec7caSRob Suderman rewriter.replaceOp(op, conv2d); 9054eec7caSRob Suderman return success(); 9154eec7caSRob Suderman } 9254eec7caSRob Suderman }; 9354eec7caSRob Suderman 9454eec7caSRob Suderman class TransposeConvStridedConverter 9554eec7caSRob Suderman : public OpRewritePattern<tosa::TransposeConv2DOp> { 9654eec7caSRob Suderman public: 9754eec7caSRob Suderman using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern; 9854eec7caSRob Suderman LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op, 9954eec7caSRob Suderman PatternRewriter &rewriter) const final { 10054eec7caSRob Suderman Location loc = op->getLoc(); 10154eec7caSRob Suderman Value input = op->getOperand(0); 10254eec7caSRob Suderman Value weight = op->getOperand(1); 10354eec7caSRob Suderman Value bias = op->getOperand(2); 10454eec7caSRob Suderman 1055550c821STres Popp ShapedType inputTy = cast<ShapedType>(input.getType()); 1065550c821STres Popp ShapedType weightTy = cast<ShapedType>(weight.getType()); 1075550c821STres Popp ShapedType biasTy = cast<ShapedType>(bias.getType()); 1085550c821STres Popp ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType()); 10954eec7caSRob Suderman 11054eec7caSRob Suderman Type inputETy = inputTy.getElementType(); 11154eec7caSRob Suderman Type weightETy = weightTy.getElementType(); 11254eec7caSRob Suderman Type biasETy = biasTy.getElementType(); 11354eec7caSRob Suderman Type resultETy = resultTy.getElementType(); 11454eec7caSRob Suderman 11511030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> pad = op.getOutPad(); 11611030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> stride = op.getStride(); 11754eec7caSRob Suderman 11854eec7caSRob Suderman // If striding is all 1 we can modify padding and reverse the kernel along 11954eec7caSRob Suderman // the x/y direction to make it a regular convolution. This is much simpler 12054eec7caSRob Suderman // then handling striding.... 12154eec7caSRob Suderman 12254eec7caSRob Suderman // If strides are all 1 we dont need to use this one. 12354eec7caSRob Suderman if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) 124fcbf3fafSRob Suderman return rewriter.notifyMatchFailure(op, "non-one stride found."); 12554eec7caSRob Suderman 12654eec7caSRob Suderman if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || 12754eec7caSRob Suderman !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) 12854eec7caSRob Suderman return failure(); 12954eec7caSRob Suderman 13054eec7caSRob Suderman int64_t batch = inputTy.getDimSize(0); 13154eec7caSRob Suderman 13254eec7caSRob Suderman int64_t outputChannels = weightTy.getDimSize(0); 13354eec7caSRob Suderman int64_t weightHeight = weightTy.getDimSize(1); 13454eec7caSRob Suderman int64_t weightWidth = weightTy.getDimSize(2); 13554eec7caSRob Suderman int64_t inputChannels = weightTy.getDimSize(3); 13654eec7caSRob Suderman 13754eec7caSRob Suderman // Pad the weight so that it is modulo of the striding. 1387e622b61SJerry-Ge llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0}; 13954eec7caSRob Suderman weightPadding[3] = 140a98a6e95Sluolent (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0; 14154eec7caSRob Suderman weightPadding[5] = 1427e622b61SJerry-Ge weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0; 1437e622b61SJerry-Ge 1447e622b61SJerry-Ge Value weightPaddingVal = 1457e622b61SJerry-Ge getTosaConstShape(rewriter, op->getLoc(), weightPadding); 14654eec7caSRob Suderman 14713448db0SJacques Pienaar if (op.getQuantizationInfo().has_value()) { 14813448db0SJacques Pienaar auto quantInfo = op.getQuantizationInfo().value(); 149c8834527STai Ly weight = CreateOpAndInferShape<tosa::PadOp>( 15054eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(weightETy), weight, 15154eec7caSRob Suderman weightPaddingVal, nullptr, 152baca1c1aSMogball rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp())); 15354eec7caSRob Suderman 15454eec7caSRob Suderman } else { 155c8834527STai Ly weight = CreateOpAndInferShape<tosa::PadOp>( 156c8834527STai Ly rewriter, loc, UnrankedTensorType::get(weightETy), weight, 157c8834527STai Ly weightPaddingVal); 15854eec7caSRob Suderman } 15954eec7caSRob Suderman 1605550c821STres Popp weightTy = cast<ShapedType>(weight.getType()); 16154eec7caSRob Suderman weightHeight = weightTy.getDimSize(1); 16254eec7caSRob Suderman weightWidth = weightTy.getDimSize(2); 16354eec7caSRob Suderman 16454eec7caSRob Suderman // Split out the width / height by the stride dimensions. 16554eec7caSRob Suderman llvm::SmallVector<int64_t, 6> weightReshapeDims0 = { 16654eec7caSRob Suderman outputChannels, weightHeight / stride[0], 16754eec7caSRob Suderman stride[0], weightWidth / stride[1], 16854eec7caSRob Suderman stride[1], inputChannels}; 169c8834527STai Ly weight = CreateOpAndInferShape<tosa::ReshapeOp>( 17054eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(weightETy), weight, 1719e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(weightReshapeDims0)); 17254eec7caSRob Suderman 17354eec7caSRob Suderman // Transpose the factored-out stride to the output channels. 17454eec7caSRob Suderman Value transposeWeightVal = rewriter.create<tosa::ConstOp>( 17554eec7caSRob Suderman loc, RankedTensorType::get({6}, rewriter.getI32Type()), 17654eec7caSRob Suderman rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5})); 17754eec7caSRob Suderman 178c8834527STai Ly weight = CreateOpAndInferShape<tosa::TransposeOp>( 17954eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(weightETy), weight, 18054eec7caSRob Suderman transposeWeightVal); 18154eec7caSRob Suderman 18254eec7caSRob Suderman // Collapse the strides and output channels into a single dimension. 18354eec7caSRob Suderman llvm::SmallVector<int64_t, 6> weightReshapeDims1 = { 18454eec7caSRob Suderman outputChannels * stride[0] * stride[1], weightHeight / stride[0], 18554eec7caSRob Suderman weightWidth / stride[1], inputChannels}; 186c8834527STai Ly weight = CreateOpAndInferShape<tosa::ReshapeOp>( 18754eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(weightETy), weight, 1889e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(weightReshapeDims1)); 1895550c821STres Popp ShapedType restridedWeightTy = cast<ShapedType>(weight.getType()); 19054eec7caSRob Suderman 191c8834527STai Ly weight = CreateOpAndInferShape<tosa::ReverseOp>( 19254eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(weightETy), weight, 193fc0d4d53STai Ly /* axis = */ rewriter.getI32IntegerAttr(1)); 194c8834527STai Ly weight = CreateOpAndInferShape<tosa::ReverseOp>( 19554eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(weightETy), weight, 196fc0d4d53STai Ly /* axis = */ rewriter.getI32IntegerAttr(2)); 19754eec7caSRob Suderman 19854eec7caSRob Suderman // We need to pad the input far enough that we can pull all values. 1997e622b61SJerry-Ge llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0}; 20054eec7caSRob Suderman inputPadding[2] += restridedWeightTy.getDimSize(1) - 1; 20154eec7caSRob Suderman inputPadding[3] += restridedWeightTy.getDimSize(1) - 1; 20254eec7caSRob Suderman inputPadding[4] += restridedWeightTy.getDimSize(2) - 1; 20354eec7caSRob Suderman inputPadding[5] += restridedWeightTy.getDimSize(2) - 1; 20454eec7caSRob Suderman 2057e622b61SJerry-Ge Value inputPaddingVal = 2067e622b61SJerry-Ge getTosaConstShape(rewriter, op->getLoc(), inputPadding); 20754eec7caSRob Suderman 20813448db0SJacques Pienaar if (op.getQuantizationInfo().has_value()) { 20913448db0SJacques Pienaar auto quantInfo = op.getQuantizationInfo().value(); 210c8834527STai Ly input = CreateOpAndInferShape<tosa::PadOp>( 21154eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(inputETy), input, 21254eec7caSRob Suderman inputPaddingVal, nullptr, 213baca1c1aSMogball rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp())); 21454eec7caSRob Suderman } else { 215c8834527STai Ly input = CreateOpAndInferShape<tosa::PadOp>( 216c8834527STai Ly rewriter, loc, UnrankedTensorType::get(inputETy), input, 217c8834527STai Ly inputPaddingVal); 21854eec7caSRob Suderman } 21954eec7caSRob Suderman 22054eec7caSRob Suderman // We use a zero bias as we need to broadcast the bias. 22154eec7caSRob Suderman auto zeroBias = rewriter.create<tosa::ConstOp>( 22254eec7caSRob Suderman loc, 22354eec7caSRob Suderman RankedTensorType::get({outputChannels * stride[0] * stride[1]}, 22454eec7caSRob Suderman biasETy), 22554eec7caSRob Suderman DenseElementsAttr::get( 22654eec7caSRob Suderman RankedTensorType::get({outputChannels * stride[0] * stride[1]}, 22754eec7caSRob Suderman biasETy), 22854eec7caSRob Suderman rewriter.getZeroAttr(biasETy))); 22954eec7caSRob Suderman 23054eec7caSRob Suderman // Perform the convolution using the zero bias. 23154eec7caSRob Suderman Value conv2d; 23213448db0SJacques Pienaar if (op.getQuantizationInfo()) { 233c8834527STai Ly conv2d = CreateOpAndInferShape<tosa::Conv2DOp>( 23454eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(resultETy), input, 23554eec7caSRob Suderman weight, zeroBias, 23611030c7dSAlexander Shaposhnikov /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), 23711030c7dSAlexander Shaposhnikov /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), 23811030c7dSAlexander Shaposhnikov /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), 239360a03c9SJack Frankland /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()) 24054eec7caSRob Suderman .getResult(); 24154eec7caSRob Suderman } else { 242c8834527STai Ly conv2d = CreateOpAndInferShape<tosa::Conv2DOp>( 24354eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(resultETy), input, 24454eec7caSRob Suderman weight, zeroBias, 24511030c7dSAlexander Shaposhnikov /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), 24611030c7dSAlexander Shaposhnikov /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), 247360a03c9SJack Frankland /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), 248360a03c9SJack Frankland /* acc_type = */ op.getAccTypeAttr()) 24954eec7caSRob Suderman .getResult(); 25054eec7caSRob Suderman } 25154eec7caSRob Suderman 25254eec7caSRob Suderman // Factor the resulting width / height. 2535550c821STres Popp ShapedType convTy = cast<ShapedType>(conv2d.getType()); 25454eec7caSRob Suderman Type convETy = convTy.getElementType(); 25554eec7caSRob Suderman 25654eec7caSRob Suderman int64_t convHeight = convTy.getDimSize(1); 25754eec7caSRob Suderman int64_t convWidth = convTy.getDimSize(2); 25854eec7caSRob Suderman 25954eec7caSRob Suderman // Factor striding out of the convolution result. 26054eec7caSRob Suderman llvm::SmallVector<int64_t, 6> convReshapeDims0 = { 26154eec7caSRob Suderman batch, convHeight, convWidth, stride[0], stride[1], outputChannels}; 262c8834527STai Ly conv2d = CreateOpAndInferShape<tosa::ReshapeOp>( 26354eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, 2649e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(convReshapeDims0)); 26554eec7caSRob Suderman 26654eec7caSRob Suderman // Transpose the factored-out stride to the output channels. 26754eec7caSRob Suderman Value transposeConvVal = rewriter.create<tosa::ConstOp>( 26854eec7caSRob Suderman loc, RankedTensorType::get({6}, rewriter.getI32Type()), 26954eec7caSRob Suderman rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5})); 27054eec7caSRob Suderman 271c8834527STai Ly conv2d = CreateOpAndInferShape<tosa::TransposeOp>( 27254eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(convETy), conv2d, 27354eec7caSRob Suderman transposeConvVal); 27454eec7caSRob Suderman 27554eec7caSRob Suderman // Fuse striding behavior back into width / height. 27654eec7caSRob Suderman llvm::SmallVector<int64_t, 6> convReshapeDims1 = { 27754eec7caSRob Suderman batch, convHeight * stride[0], convWidth * stride[1], outputChannels}; 278c8834527STai Ly conv2d = CreateOpAndInferShape<tosa::ReshapeOp>( 27954eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, 2809e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(convReshapeDims1)); 28154eec7caSRob Suderman 282fcbf3fafSRob Suderman // Determine the amount to slice / pad from the result start. 283fcbf3fafSRob Suderman int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]); 284fcbf3fafSRob Suderman int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]); 285fcbf3fafSRob Suderman int64_t resultPadTop = std::max<int64_t>(0, pad[0]); 286fcbf3fafSRob Suderman int64_t resultPadLeft = std::max<int64_t>(0, pad[2]); 287fcbf3fafSRob Suderman 288fcbf3fafSRob Suderman // Try to slice the targetted result size, cap to the convolutions width. 289fcbf3fafSRob Suderman int64_t resultSliceHeight = 290fcbf3fafSRob Suderman std::min<int64_t>(convReshapeDims1[1] - resultSliceTop, 291fcbf3fafSRob Suderman resultTy.getDimSize(1) - resultPadTop); 292fcbf3fafSRob Suderman int64_t resultSliceWidth = 293fcbf3fafSRob Suderman std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft, 294fcbf3fafSRob Suderman resultTy.getDimSize(2) - resultPadLeft); 295fcbf3fafSRob Suderman 296fcbf3fafSRob Suderman llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop, 297fcbf3fafSRob Suderman resultSliceLeft, 0}; 298fcbf3fafSRob Suderman llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(), 299fcbf3fafSRob Suderman convReshapeDims1.end()); 300fcbf3fafSRob Suderman sliceSize[1] = resultSliceHeight; 301fcbf3fafSRob Suderman sliceSize[2] = resultSliceWidth; 30254eec7caSRob Suderman 303c8834527STai Ly auto slice = CreateOpAndInferShape<tosa::SliceOp>( 30454eec7caSRob Suderman rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, 305*956c0707SJerry-Ge getTosaConstShape(rewriter, loc, sliceBegin), 306*956c0707SJerry-Ge getTosaConstShape(rewriter, loc, sliceSize)) 30754eec7caSRob Suderman .getResult(); 30854eec7caSRob Suderman 3097e622b61SJerry-Ge llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0}; 310fcbf3fafSRob Suderman resultPadding[2] = resultPadTop; 311fcbf3fafSRob Suderman resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1]; 312fcbf3fafSRob Suderman resultPadding[4] = resultPadLeft; 313fcbf3fafSRob Suderman resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2]; 31454eec7caSRob Suderman 3157e622b61SJerry-Ge Value resultPaddingVal = 3167e622b61SJerry-Ge getTosaConstShape(rewriter, op->getLoc(), resultPadding); 317fcbf3fafSRob Suderman 318c8834527STai Ly Value resultPad = CreateOpAndInferShape<tosa::PadOp>( 319fcbf3fafSRob Suderman rewriter, loc, UnrankedTensorType::get(resultETy), slice, 320fcbf3fafSRob Suderman resultPaddingVal); 321fcbf3fafSRob Suderman 322e0537d1aSTai Ly if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) { 323e0537d1aSTai Ly return failure(); 324e0537d1aSTai Ly } 325e0537d1aSTai Ly 326fcbf3fafSRob Suderman rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias); 32754eec7caSRob Suderman return success(); 32854eec7caSRob Suderman } 32954eec7caSRob Suderman }; 33054eec7caSRob Suderman 331be0a7e9fSMehdi Amini } // namespace 33254eec7caSRob Suderman 333dfd07082SAaron DeBattista void mlir::tosa::populateTosaDecomposeTransposeConv( 334dfd07082SAaron DeBattista MLIRContext *ctx, RewritePatternSet &patterns) { 335b7f4335dSEric Kunze patterns.add<TransposeConvNonStridedConverter>(ctx); 336b4e0507cSTres Popp patterns.add<TransposeConvStridedConverter>(ctx); 33754eec7caSRob Suderman } 338