xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (revision 7e622b61320543b3706711609f1f32fd9ea3788d)
1 //===- TosaDecomposeConv2D.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
10 // (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
17 
18 using namespace mlir;
19 using namespace mlir::tosa;
20 
21 namespace {
22 
23 SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
24   return to_vector(llvm::map_range(shape, [](int64_t dim) {
25     return ShapedType::isDynamic(dim) ? -1 : dim;
26   }));
27 }
28 
/// Rewrites a stride-1, 1x1-kernel tosa.conv2d into
///   reshape(input)  -> [N*IH*IW, IC]
///   reshape(weight) -> [OC, IC]
///   tosa.fully_connected
///   reshape(result) -> original conv output type
/// Any conv padding is first materialized as an explicit tosa.pad so the
/// matmul-style rewrite can assume a pad-free convolution.
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
  explicit Conv2DIsFullyConnected(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value weight = op.getWeight();
    ShapedType inputType = cast<ShapedType>(input.getType());
    ShapedType weightType = cast<ShapedType>(weight.getType());
    ShapedType resultType = cast<ShapedType>(op.getType());

    // Counted over the *pre-padding* input shape and consumed further down
    // when collapsing N*IH*IW: the collapsed product is only known statically
    // when no input dim is dynamic.
    // NOTE(review): if the single dynamic dim is the channel (dim 3), both
    // entries of the collapsed reshape shape become dynamic — confirm the
    // verifier rejects or upstream guards that case.
    auto numDynamic =
        llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
    if (numDynamic > 1)
      return rewriter.notifyMatchFailure(
          op, "at most one dim in input may be dynamic");
    if (!weightType.hasRank())
      return rewriter.notifyMatchFailure(op, "unranked weight input");

    // A strided convolution cannot be expressed as a single matmul.
    if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
      return failure();

    // Only works for a 1x1 kernel.
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[1] != 1 || weightShape[2] != 1)
      return failure();

    // Expand the conv's 4-entry pad attribute into the 8-entry
    // (low, high per dimension) layout tosa.pad expects for a rank-4 tensor;
    // batch (dim 0) and channel (dim 3) stay unpadded.
    llvm::ArrayRef<int64_t> padAttr = op.getPad();
    llvm::SmallVector<int64_t> pad(8, 0);
    for (const auto &it : llvm::enumerate(padAttr))
      pad[it.index() + 2] = it.value();

    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
      Type inputETy = inputType.getElementType();
      Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
      if (op.getQuantizationInfo()) {
        auto quantizationInfo = op.getQuantizationInfo();
        int64_t iZp = quantizationInfo->getInputZp();

        if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
          return rewriter.notifyMatchFailure(
              op, "tosa.conv op quantization has zp outside of input range");

        // Quantized convs must pad with the input zero point, not literal 0.
        zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
      }

      llvm::SmallVector<int64_t> newShape(inputType.getShape());

      // Grow each static dim by its low+high padding; dynamic dims are left
      // dynamic.
      for (int i = 0, s = newShape.size(); i < s; ++i) {
        if (newShape[i] != ShapedType::kDynamic) {
          newShape[i] += pad[i * 2] + pad[i * 2 + 1];
        }
      }

      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);

      // Splat the pad value as a rank-0 constant tensor.
      auto padTy = RankedTensorType::get({}, inputETy);
      auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
      Value padVal =
          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
      // From here on, `input`/`inputType` refer to the padded tensor.
      inputType = RankedTensorType::get(newShape, inputETy);
      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
                                           padSizeVal, padVal);
    }

    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t combined = ShapedType::kDynamic;
    if (numDynamic == 0)
      combined = inputShape[0] * inputShape[1] * inputShape[2];
    llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
    auto revisedInputShapeType =
        RankedTensorType::get(revisedInputShape, inputType.getElementType());
    auto reshapedInput = rewriter
                             .create<tosa::ReshapeOp>(
                                 op.getLoc(), revisedInputShapeType, input,
                                 rewriter.getDenseI64ArrayAttr(
                                     convertFromMlirShape(revisedInputShape)))
                             .getResult();

    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
                                                     weightShape[3]};
    auto revisedWeightShapeType = RankedTensorType::get(
        revisedWeightShape,
        dyn_cast<RankedTensorType>(weight.getType()).getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
                                  rewriter.getDenseI64ArrayAttr(
                                      convertFromMlirShape(revisedWeightShape)))
                              .getResult();

    // Perform a fully connected network over the reshaped input and weight.
    llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
    auto fullyConnectedShapeType =
        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());

    // Quantization info, when present, must be forwarded to the new op.
    Value fullyConnectedValue;
    if (op.getQuantizationInfo()) {
      fullyConnectedValue =
          rewriter
              .create<tosa::FullyConnectedOp>(
                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
                  reshapedWeight, op.getBias(), *op.getQuantizationInfo())
              .getResult();
    } else {
      fullyConnectedValue = rewriter
                                .create<tosa::FullyConnectedOp>(
                                    op.getLoc(), fullyConnectedShapeType,
                                    reshapedInput, reshapedWeight, op.getBias())
                                .getResult();
    }

    // Reshape output to [N, IH, IW, OC].
    // `inputShape` is the post-padding shape here; with a 1x1 kernel and
    // stride 1 the padded IH/IW equal the conv's OH/OW, so this matches
    // `resultType`.
    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
                                              inputShape[2], weightShape[0]};
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultType, fullyConnectedValue,
        rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
    return success();
  }
};
153 
154 } // namespace
155 
/// Registers the Conv2D decomposition patterns (currently only the
/// 1x1-conv-to-fully-connected rewrite) into `patterns`.
void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
                                             RewritePatternSet &patterns) {
  patterns.add<Conv2DIsFullyConnected>(ctx);
}
160