xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (revision a58e774fba42e13aa00667d644e96b783fc914b4)
13bcaf2ebSGeorgios Pinitas //===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2dfd07082SAaron DeBattista //
3dfd07082SAaron DeBattista // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4dfd07082SAaron DeBattista // See https://llvm.org/LICENSE.txt for license information.
5dfd07082SAaron DeBattista // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6dfd07082SAaron DeBattista //
7dfd07082SAaron DeBattista //===----------------------------------------------------------------------===//
8dfd07082SAaron DeBattista //
9dfd07082SAaron DeBattista // Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
10dfd07082SAaron DeBattista // (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
11dfd07082SAaron DeBattista //
12dfd07082SAaron DeBattista //===----------------------------------------------------------------------===//
13dfd07082SAaron DeBattista 
14dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/Transforms/Passes.h"
16e0537d1aSTai Ly #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
17*a58e774fSJack Frankland #include "mlir/IR/BuiltinTypes.h"
18dfd07082SAaron DeBattista #include "mlir/Pass/Pass.h"
19dfd07082SAaron DeBattista 
20dfd07082SAaron DeBattista using namespace mlir;
21dfd07082SAaron DeBattista using namespace mlir::tosa;
22dfd07082SAaron DeBattista 
23dfd07082SAaron DeBattista namespace {
24dfd07082SAaron DeBattista 
25dfd07082SAaron DeBattista struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
26dfd07082SAaron DeBattista   explicit DepthwiseConv2DIsMul(MLIRContext *context)
27dfd07082SAaron DeBattista       : OpRewritePattern(context) {}
28dfd07082SAaron DeBattista 
29dfd07082SAaron DeBattista   LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
30dfd07082SAaron DeBattista                                 PatternRewriter &rewriter) const override {
3113448db0SJacques Pienaar     Value input = op.getInput();
3213448db0SJacques Pienaar     Value weight = op.getWeight();
335550c821STres Popp     ShapedType inputType = cast<ShapedType>(input.getType());
345550c821STres Popp     ShapedType weightType = cast<ShapedType>(weight.getType());
355550c821STres Popp     ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
36dfd07082SAaron DeBattista 
37dfd07082SAaron DeBattista     if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
38dfd07082SAaron DeBattista           resultType.hasStaticShape())) {
39dfd07082SAaron DeBattista       return failure();
40dfd07082SAaron DeBattista     }
41dfd07082SAaron DeBattista 
4211030c7dSAlexander Shaposhnikov     if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
43dfd07082SAaron DeBattista       return failure();
44dfd07082SAaron DeBattista 
45dfd07082SAaron DeBattista     // Only works for a 1x1 kernel.
46dfd07082SAaron DeBattista     ArrayRef<int64_t> weightShape = weightType.getShape();
47dfd07082SAaron DeBattista     if (weightShape[0] != 1 || weightShape[1] != 1) {
48dfd07082SAaron DeBattista       return failure();
49dfd07082SAaron DeBattista     }
50dfd07082SAaron DeBattista 
51dfd07082SAaron DeBattista     // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
52dfd07082SAaron DeBattista     ArrayRef<int64_t> inputShape = inputType.getShape();
53dfd07082SAaron DeBattista     llvm::SmallVector<int64_t, 2> revisedInputShape{
54dfd07082SAaron DeBattista         inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
5569c984b6SRob Suderman     inputType = RankedTensorType::get(
56dfd07082SAaron DeBattista         revisedInputShape,
575550c821STres Popp         dyn_cast<RankedTensorType>(input.getType()).getElementType());
5869c984b6SRob Suderman     input = rewriter
59dfd07082SAaron DeBattista                 .create<tosa::ReshapeOp>(
6069c984b6SRob Suderman                     op.getLoc(), inputType, input,
619e1a3441SAlexander Shaposhnikov                     rewriter.getDenseI64ArrayAttr(revisedInputShape))
62dfd07082SAaron DeBattista                 .getResult();
63dfd07082SAaron DeBattista 
6469c984b6SRob Suderman     if (inputType.getElementType() != resultType.getElementType()) {
6569c984b6SRob Suderman       inputType = inputType.clone(resultType.getElementType());
6669c984b6SRob Suderman       input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
6769c984b6SRob Suderman     }
6869c984b6SRob Suderman 
6969c984b6SRob Suderman     if (weightType.getElementType() != resultType.getElementType()) {
7069c984b6SRob Suderman       weightType = weightType.clone(resultType.getElementType());
7169c984b6SRob Suderman       weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
7269c984b6SRob Suderman     }
7369c984b6SRob Suderman 
7469c984b6SRob Suderman     if (auto quantizationInfo = op.getQuantizationInfo()) {
7569c984b6SRob Suderman       auto iZp = quantizationInfo->getInputZp();
7669c984b6SRob Suderman       auto wZp = quantizationInfo->getWeightZp();
7769c984b6SRob Suderman 
7869c984b6SRob Suderman       auto applyZp = [&](Value val, int64_t zp) -> Value {
7969c984b6SRob Suderman         if (zp == 0)
8069c984b6SRob Suderman           return val;
815550c821STres Popp         auto ety = cast<ShapedType>(val.getType()).getElementType();
82e0537d1aSTai Ly         std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
83e0537d1aSTai Ly                                    1);
84e0537d1aSTai Ly         auto zpTy = RankedTensorType::get(shape, ety);
8569c984b6SRob Suderman         auto zpAttr =
8669c984b6SRob Suderman             DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
8769c984b6SRob Suderman         auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
8869c984b6SRob Suderman         return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
8969c984b6SRob Suderman                                             zpVal);
9069c984b6SRob Suderman       };
9169c984b6SRob Suderman 
9269c984b6SRob Suderman       input = applyZp(input, iZp);
9369c984b6SRob Suderman       weight = applyZp(weight, wZp);
9469c984b6SRob Suderman     }
9569c984b6SRob Suderman 
9611030c7dSAlexander Shaposhnikov     ArrayRef<int64_t> padAttr = op.getPad();
9769c984b6SRob Suderman     llvm::SmallVector<int64_t> pad(10, 0);
9811030c7dSAlexander Shaposhnikov     for (const auto &it : llvm::enumerate(padAttr))
9911030c7dSAlexander Shaposhnikov       pad[it.index() + 2] = it.value();
10069c984b6SRob Suderman 
10169c984b6SRob Suderman     if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
10269c984b6SRob Suderman       Type inputETy = inputType.getElementType();
10369c984b6SRob Suderman       Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
10469c984b6SRob Suderman 
10569c984b6SRob Suderman       llvm::SmallVector<int64_t> newShape(inputType.getShape());
10669c984b6SRob Suderman       for (int i = 0, s = pad.size(); i < s; ++i) {
10769c984b6SRob Suderman         if (newShape[i / 2] != ShapedType::kDynamic) {
10869c984b6SRob Suderman           newShape[i / 2] += pad[i];
10969c984b6SRob Suderman         }
11069c984b6SRob Suderman       }
11169c984b6SRob Suderman 
1127e622b61SJerry-Ge       Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
11369c984b6SRob Suderman 
11469c984b6SRob Suderman       auto padTy = RankedTensorType::get({}, inputETy);
11569c984b6SRob Suderman       auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
11669c984b6SRob Suderman       Value padVal =
11769c984b6SRob Suderman           rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
11869c984b6SRob Suderman       inputType = RankedTensorType::get(newShape, inputETy);
11969c984b6SRob Suderman       input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
12069c984b6SRob Suderman                                            padSizeVal, padVal);
12169c984b6SRob Suderman     }
122dfd07082SAaron DeBattista 
123dfd07082SAaron DeBattista     // Perform an elementwise mul over the reshaped input and weight.
12469c984b6SRob Suderman     llvm::SmallVector<int64_t, 2> mulShape{
12569c984b6SRob Suderman         inputType.getDimSize(0), inputType.getDimSize(1),
12669c984b6SRob Suderman         inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
127dfd07082SAaron DeBattista     auto mulShapeType = RankedTensorType::get(
128dfd07082SAaron DeBattista         mulShape,
1295550c821STres Popp         dyn_cast<RankedTensorType>(weight.getType()).getElementType());
130e0537d1aSTai Ly 
131e0537d1aSTai Ly     if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
132e0537d1aSTai Ly       return failure();
133e0537d1aSTai Ly     }
134e0537d1aSTai Ly 
135*a58e774fSJack Frankland     auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
136*a58e774fSJack Frankland     auto shiftType = RankedTensorType::get({1}, shiftElementType);
137*a58e774fSJack Frankland     auto shiftZeroAttr = DenseElementsAttr::get(
138*a58e774fSJack Frankland         shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
139*a58e774fSJack Frankland     Value constZero =
140*a58e774fSJack Frankland         rewriter.create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr);
14169c984b6SRob Suderman     Value mulValue = rewriter
14269c984b6SRob Suderman                          .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
143*a58e774fSJack Frankland                                               weight, constZero)
144dfd07082SAaron DeBattista                          .getResult();
145dfd07082SAaron DeBattista 
146dfd07082SAaron DeBattista     // Reshape output to [N, H, W, C * M].
1475550c821STres Popp     auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
148dfd07082SAaron DeBattista     auto outputShapeType = RankedTensorType::get(
149dfd07082SAaron DeBattista         outputShape,
1505550c821STres Popp         dyn_cast<RankedTensorType>(input.getType()).getElementType());
151e0537d1aSTai Ly     Value outputValue = rewriter.create<tosa::ReshapeOp>(
1529e1a3441SAlexander Shaposhnikov         op.getLoc(), outputShapeType, mulValue,
1539e1a3441SAlexander Shaposhnikov         rewriter.getDenseI64ArrayAttr(outputShape));
154dfd07082SAaron DeBattista 
155e0537d1aSTai Ly     Value bias = op.getBias();
156e0537d1aSTai Ly     if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
157e0537d1aSTai Ly       return failure();
158e0537d1aSTai Ly     }
159e0537d1aSTai Ly 
160dfd07082SAaron DeBattista     // Add in the bias.
161dfd07082SAaron DeBattista     rewriter
162e0537d1aSTai Ly         .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
163dfd07082SAaron DeBattista         .getResult();
164dfd07082SAaron DeBattista     return success();
165dfd07082SAaron DeBattista   }
166dfd07082SAaron DeBattista };
167dfd07082SAaron DeBattista 
168dfd07082SAaron DeBattista } // namespace
169dfd07082SAaron DeBattista 
170dfd07082SAaron DeBattista void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
171dfd07082SAaron DeBattista                                                 RewritePatternSet &patterns) {
172b4e0507cSTres Popp   patterns.add<DepthwiseConv2DIsMul>(ctx);
173dfd07082SAaron DeBattista }
174