//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D to a Conv2D by reversing and
//     re-padding the weights.
// (2) Convert a strided TransposeConv2D to a Conv2D by transposing, reversing,
//     and reshaping the weights and the input/output tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If all strides are 1 we can adjust the padding and reverse the kernel
    // along the x/y directions to make this a regular convolution. This is
    // much simpler than handling striding.
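    // For example, a transpose_conv2d with a 3x3 kernel and out_pad
    // [0, 0, 0, 0] becomes a conv2d with pad [2, 2, 2, 2] (computed below as
    // kernel - 1 + out_pad per side) over the kernel reversed along H and W.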
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}),
          /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}),
          /* acc_type = */ op.getAccTypeAttr());
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // If all strides are 1, the non-strided pattern above handles this op.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(
          op, "all stride values equal one; handled by non-strided pattern");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);
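    // Strategy: factor the stride out of the kernel so the transpose
    // convolution becomes a stride-1 conv2d. The weight is padded to a
    // multiple of the stride, split by stride into extra dimensions,
    // transposed so the stride factors merge into the output channels, and
    // reversed along H/W. After the conv2d, the stride factors are reshaped
    // back into the spatial dimensions of the result.
    //
    // For example, weight [OC, KH, KW, IC] = [8, 5, 5, 3] with stride [2, 2]
    // is padded to [8, 6, 6, 3], reshaped to [8, 3, 2, 3, 2, 3], transposed
    // to [2, 2, 8, 3, 3, 3], and collapsed to [32, 3, 3, 3].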
    // Pad the weight so its spatial dimensions are an even multiple of the
    // stride.
    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
    weightPadding[5] =
        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;

    Value weightPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), weightPadding);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      weight = CreateOpAndInferShape<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
    } else {
      weight = CreateOpAndInferShape<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal);
    }

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

    // Reverse the kernel along H and W, as in the non-strided case.
    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));

    // Pad the input far enough that every output position can pull values
    // from it.
    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    Value inputPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), inputPadding);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      input = CreateOpAndInferShape<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
    } else {
      input = CreateOpAndInferShape<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal);
    }

    // Use a zero bias for the convolution; the real bias is broadcast and
    // added at the end.
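    // The inner conv2d produces outputChannels * stride[0] * stride[1]
    // channels, so the original bias (one value per output channel) cannot be
    // attached here; it is added once the result is back in its final shape.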
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
                   .getResult();
    } else {
      conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /* acc_type = */ op.getAccTypeAttr())
                   .getResult();
    }

    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims0));

    // Interleave the factored-out stride with the spatial dimensions.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse the striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims1));

    // Determine the amount to slice / pad from the result start.
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);

    // Try to slice to the targeted result size, capped to the convolution's
    // width / height.
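    // A negative out_pad entry trims rows / columns from the result (handled
    // by the slice below), while a positive entry adds zero rows / columns
    // (handled by the final pad), so the output matches resultTy exactly.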
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;

    auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     getTosaConstShape(rewriter, loc, sliceBegin),
                     getTosaConstShape(rewriter, loc, sliceSize))
                     .getResult();

    // Pad the remaining rows / columns so the result matches the target shape.
    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    Value resultPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), resultPadding);

    Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}