xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (revision a58e774fba42e13aa00667d644e96b783fc914b4)
1 //===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
10 // (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/Pass/Pass.h"
19 
20 using namespace mlir;
21 using namespace mlir::tosa;
22 
23 namespace {
24 
/// Rewrites a tosa.depthwise_conv2d with a 1x1 kernel and unit strides into
/// an elementwise multiply plus reshapes and a bias add:
///   reshape(input) [-> cast] [-> zero-point sub] [-> pad] -> mul
///     -> reshape -> add(bias)
/// Only fully static shapes are handled.
struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
  explicit DepthwiseConv2DIsMul(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value weight = op.getWeight();
    ShapedType inputType = cast<ShapedType>(input.getType());
    ShapedType weightType = cast<ShapedType>(weight.getType());
    ShapedType resultType = cast<ShapedType>(op.getOutput().getType());

    // All intermediate shapes below are computed statically; bail out on any
    // dynamic dimension.
    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
          resultType.hasStaticShape())) {
      return failure();
    }

    // A strided conv is not a pure elementwise multiply; require stride 1.
    if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
      return failure();

    // Only works for a 1x1 kernel.
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[0] != 1 || weightShape[1] != 1) {
      return failure();
    }

    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1] so it can broadcast
    // against the (rank-equalized) [1, 1, 1, C, M] weight in the mul below.
    ArrayRef<int64_t> inputShape = inputType.getShape();
    llvm::SmallVector<int64_t, 2> revisedInputShape{
        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
    inputType = RankedTensorType::get(
        revisedInputShape,
        dyn_cast<RankedTensorType>(input.getType()).getElementType());
    input = rewriter
                .create<tosa::ReshapeOp>(
                    op.getLoc(), inputType, input,
                    rewriter.getDenseI64ArrayAttr(revisedInputShape))
                .getResult();

    // Widen the operands to the result (accumulator) element type so the
    // decomposed mul/add run at the conv's accumulation precision.
    if (inputType.getElementType() != resultType.getElementType()) {
      inputType = inputType.clone(resultType.getElementType());
      input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
    }

    if (weightType.getElementType() != resultType.getElementType()) {
      weightType = weightType.clone(resultType.getElementType());
      weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
    }

    // Make the quantization zero points explicit: subtract each operand's
    // zero point (as a rank-matched splat constant) before multiplying.
    if (auto quantizationInfo = op.getQuantizationInfo()) {
      auto iZp = quantizationInfo->getInputZp();
      auto wZp = quantizationInfo->getWeightZp();

      // Returns `val` unchanged when zp == 0; otherwise emits val - zp.
      auto applyZp = [&](Value val, int64_t zp) -> Value {
        if (zp == 0)
          return val;
        auto ety = cast<ShapedType>(val.getType()).getElementType();
        std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
                                   1);
        auto zpTy = RankedTensorType::get(shape, ety);
        auto zpAttr =
            DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
        auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
        return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
                                            zpVal);
      };

      input = applyZp(input, iZp);
      weight = applyZp(weight, wZp);
    }

    // Expand the conv's [top, bottom, left, right] H/W padding into a
    // rank-5 pad spec (two entries per dim); N and C/1 dims stay unpadded.
    ArrayRef<int64_t> padAttr = op.getPad();
    llvm::SmallVector<int64_t> pad(10, 0);
    for (const auto &it : llvm::enumerate(padAttr))
      pad[it.index() + 2] = it.value();

    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
      Type inputETy = inputType.getElementType();
      Attribute zeroAttr = rewriter.getZeroAttr(inputETy);

      // Grow each padded dimension of the (already reshaped) input type.
      llvm::SmallVector<int64_t> newShape(inputType.getShape());
      for (int i = 0, s = pad.size(); i < s; ++i) {
        if (newShape[i / 2] != ShapedType::kDynamic) {
          newShape[i / 2] += pad[i];
        }
      }

      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);

      // Pad with zero. This is correct because the zero point (if any) was
      // already subtracted above, so zero is the neutral fill value here.
      auto padTy = RankedTensorType::get({}, inputETy);
      auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
      Value padVal =
          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
      inputType = RankedTensorType::get(newShape, inputETy);
      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
                                           padSizeVal, padVal);
    }

    // Perform an elementwise mul over the reshaped input and weight.
    // Result shape is [N, H, W, C, M] (broadcast over the C/M trailing dims).
    llvm::SmallVector<int64_t, 2> mulShape{
        inputType.getDimSize(0), inputType.getDimSize(1),
        inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
    auto mulShapeType = RankedTensorType::get(
        mulShape,
        dyn_cast<RankedTensorType>(weight.getType()).getElementType());

    // Bring the rank-4 weight up to the input's rank 5 so the mul operands
    // broadcast; EqualizeRanks updates `input`/`weight` in place.
    if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
      return failure();
    }

    // tosa.mul takes a shift operand (i8 tensor); use 0, i.e. no shift.
    auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
    auto shiftType = RankedTensorType::get({1}, shiftElementType);
    auto shiftZeroAttr = DenseElementsAttr::get(
        shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
    Value constZero =
        rewriter.create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr);
    Value mulValue = rewriter
                         .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
                                              weight, constZero)
                         .getResult();

    // Reshape output to [N, H, W, C * M].
    auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
    auto outputShapeType = RankedTensorType::get(
        outputShape,
        dyn_cast<RankedTensorType>(input.getType()).getElementType());
    Value outputValue = rewriter.create<tosa::ReshapeOp>(
        op.getLoc(), outputShapeType, mulValue,
        rewriter.getDenseI64ArrayAttr(outputShape));

    Value bias = op.getBias();
    // Rank-equalize the bias against the rank-4 output before the add.
    if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
      return failure();
    }

    // Add in the bias.
    rewriter
        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
        .getResult();
    return success();
  }
};
167 
168 } // namespace
169 
/// Registers the depthwise-conv2d decomposition pattern defined in this file
/// into `patterns`, for use by the TOSA optional-decompositions pass.
void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                                RewritePatternSet &patterns) {
  patterns.add<DepthwiseConv2DIsMul>(ctx);
}
174