//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D to a regular Conv2D by adjusting
//     the padding and reversing the weights along the spatial axes.
// (2) Convert a strided TransposeConv2D to a Conv2D by padding, reshaping,
//     transposing, and reversing the weights, and by padding, slicing, and
//     reshaping the input and output tensors.
//
//===----------------------------------------------------------------------===//
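
// For intuition (an illustrative note, not used by the code below): a
// stride-1 transposed convolution is equivalent to a regular convolution
// whose kernel is reversed along both spatial axes, with (kernel_size - 1)
// padding added on each side. In 1-D, transpose_conv(x, k) with a 3-tap
// kernel matches conv(pad(x, 2, 2), reverse(k)).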

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If the stride is all 1s we can modify the padding and reverse the
    // kernel along the x/y direction to make it a regular convolution. This
    // is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];
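    // Worked example (illustrative): for a 3x3 kernel with an out_pad of
    // {0, 0, 0, 0}, convPad becomes {2, 2, 2, 2}, i.e. a "full" convolution
    // over the reversed kernel, which reproduces the transposed convolution.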

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}),
          /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getDenseI64ArrayAttr(convPad),
          rewriter.getDenseI64ArrayAttr(stride),
          rewriter.getDenseI64ArrayAttr({1, 1}),
          /* acc_type = */ op.getAccTypeAttr());
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // If the stride is all 1s this pattern is not needed; the non-strided
    // pattern above handles that case by adjusting the padding and reversing
    // the kernel.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(
          op, "stride of 1 handled by the non-strided pattern.");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so its spatial dimensions are multiples of the stride.
    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
    weightPadding[5] =
        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
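    // Worked example (illustrative): weightHeight = 5 with stride[0] = 2
    // gives weightPadding[3] = 1, padding the kernel height up to 6 so it
    // divides evenly by the stride.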

    Value weightPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), weightPadding);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      weight = CreateOpAndInferShape<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
    } else {
      weight = CreateOpAndInferShape<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal);
    }

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
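    // Shape flow (illustrative): with outputChannels = 8, a padded 6x6
    // kernel, inputChannels = 3, and stride = {2, 2}:
    //   [8, 6, 6, 3] -> reshape -> [8, 3, 2, 3, 2, 3]
    //                -> transpose {2, 4, 0, 1, 3, 5} -> [2, 2, 8, 3, 3, 3]
    //                -> reshape -> [32, 3, 3, 3]
    // yielding one stride[0] * stride[1] group of sub-kernels per output
    // channel.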
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));

    // Pad the input far enough that the unit-stride convolution can read
    // every value it needs.
    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    Value inputPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), inputPadding);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      input = CreateOpAndInferShape<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
    } else {
      input = CreateOpAndInferShape<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal);
    }

    // Use a zero bias for the convolution; the real bias is broadcast and
    // added back at the end.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));
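    // The conv output below has outputChannels * stride[0] * stride[1]
    // channels, so the original outputChannels-sized bias cannot be added
    // until the reshapes further down restore the true channel count.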

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
                   .getResult();
    } else {
      conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                   /* acc_type = */ op.getAccTypeAttr())
                   .getResult();
    }

    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims0));

    // Interleave the factored-out strides with the spatial dimensions.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse the striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getDenseI64ArrayAttr(convReshapeDims1));
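    // Shape flow (illustrative): with batch = 1, a 4x4 convolution result,
    // outputChannels = 8, and stride = {2, 2}:
    //   [1, 4, 4, 32] -> reshape -> [1, 4, 4, 2, 2, 8]
    //                 -> transpose {0, 1, 3, 2, 4, 5} -> [1, 4, 2, 4, 2, 8]
    //                 -> reshape -> [1, 8, 8, 8]
    // interleaving the factored-out strides back into height and width.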

    // Determine the amount to slice / pad from the result start.
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
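    // A negative out_pad entry trims rows / columns from the result (slice),
    // while a positive entry appends zero rows / columns (pad); the max()
    // calls above split each out_pad value into those two contributions.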

    // Try to slice to the targeted result size, capped to the convolution's
    // width / height.
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;

    auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     getTosaConstShape(rewriter, loc, sliceBegin),
                     getTosaConstShape(rewriter, loc, sliceSize))
                     .getResult();

    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    Value resultPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), resultPadding);

    Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}
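
// Example driver (an illustrative sketch, not part of this file's API;
// assumes the standard greedy pattern driver and a func::FuncOp root):
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   void runDecomposeTransposeConv(func::FuncOp func) {
//     MLIRContext *ctx = func.getContext();
//     RewritePatternSet patterns(ctx);
//     mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
//     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
//   }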