13bcaf2ebSGeorgios Pinitas //===- TosaDecomposeDepthwise.cpp -----------------------------------------===// 2dfd07082SAaron DeBattista // 3dfd07082SAaron DeBattista // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4dfd07082SAaron DeBattista // See https://llvm.org/LICENSE.txt for license information. 5dfd07082SAaron DeBattista // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6dfd07082SAaron DeBattista // 7dfd07082SAaron DeBattista //===----------------------------------------------------------------------===// 8dfd07082SAaron DeBattista // 9dfd07082SAaron DeBattista // Decompose TOSA Depthwise operation to a series of TOSA Ops specifically 10dfd07082SAaron DeBattista // (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add 11dfd07082SAaron DeBattista // 12dfd07082SAaron DeBattista //===----------------------------------------------------------------------===// 13dfd07082SAaron DeBattista 14dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/Transforms/Passes.h" 16e0537d1aSTai Ly #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 17*a58e774fSJack Frankland #include "mlir/IR/BuiltinTypes.h" 18dfd07082SAaron DeBattista #include "mlir/Pass/Pass.h" 19dfd07082SAaron DeBattista 20dfd07082SAaron DeBattista using namespace mlir; 21dfd07082SAaron DeBattista using namespace mlir::tosa; 22dfd07082SAaron DeBattista 23dfd07082SAaron DeBattista namespace { 24dfd07082SAaron DeBattista 25dfd07082SAaron DeBattista struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { 26dfd07082SAaron DeBattista explicit DepthwiseConv2DIsMul(MLIRContext *context) 27dfd07082SAaron DeBattista : OpRewritePattern(context) {} 28dfd07082SAaron DeBattista 29dfd07082SAaron DeBattista LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op, 30dfd07082SAaron DeBattista PatternRewriter &rewriter) const override { 3113448db0SJacques Pienaar Value input = op.getInput(); 3213448db0SJacques Pienaar Value weight = op.getWeight(); 335550c821STres Popp ShapedType inputType = cast<ShapedType>(input.getType()); 345550c821STres Popp ShapedType weightType = cast<ShapedType>(weight.getType()); 355550c821STres Popp ShapedType resultType = cast<ShapedType>(op.getOutput().getType()); 36dfd07082SAaron DeBattista 37dfd07082SAaron DeBattista if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && 38dfd07082SAaron DeBattista resultType.hasStaticShape())) { 39dfd07082SAaron DeBattista return failure(); 40dfd07082SAaron DeBattista } 41dfd07082SAaron DeBattista 4211030c7dSAlexander Shaposhnikov if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; })) 43dfd07082SAaron DeBattista return failure(); 44dfd07082SAaron DeBattista 45dfd07082SAaron DeBattista // Only works for a 1x1 kernel. 46dfd07082SAaron DeBattista ArrayRef<int64_t> weightShape = weightType.getShape(); 47dfd07082SAaron DeBattista if (weightShape[0] != 1 || weightShape[1] != 1) { 48dfd07082SAaron DeBattista return failure(); 49dfd07082SAaron DeBattista } 50dfd07082SAaron DeBattista 51dfd07082SAaron DeBattista // Reshape input to [N, H, W, C] -> [N, H, W, C, 1]. 52dfd07082SAaron DeBattista ArrayRef<int64_t> inputShape = inputType.getShape(); 53dfd07082SAaron DeBattista llvm::SmallVector<int64_t, 2> revisedInputShape{ 54dfd07082SAaron DeBattista inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1}; 5569c984b6SRob Suderman inputType = RankedTensorType::get( 56dfd07082SAaron DeBattista revisedInputShape, 575550c821STres Popp dyn_cast<RankedTensorType>(input.getType()).getElementType()); 5869c984b6SRob Suderman input = rewriter 59dfd07082SAaron DeBattista .create<tosa::ReshapeOp>( 6069c984b6SRob Suderman op.getLoc(), inputType, input, 619e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(revisedInputShape)) 62dfd07082SAaron DeBattista .getResult(); 63dfd07082SAaron DeBattista 6469c984b6SRob Suderman if (inputType.getElementType() != resultType.getElementType()) { 6569c984b6SRob Suderman inputType = inputType.clone(resultType.getElementType()); 6669c984b6SRob Suderman input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input); 6769c984b6SRob Suderman } 6869c984b6SRob Suderman 6969c984b6SRob Suderman if (weightType.getElementType() != resultType.getElementType()) { 7069c984b6SRob Suderman weightType = weightType.clone(resultType.getElementType()); 7169c984b6SRob Suderman weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight); 7269c984b6SRob Suderman } 7369c984b6SRob Suderman 7469c984b6SRob Suderman if (auto quantizationInfo = op.getQuantizationInfo()) { 7569c984b6SRob Suderman auto iZp = quantizationInfo->getInputZp(); 7669c984b6SRob Suderman auto wZp = quantizationInfo->getWeightZp(); 7769c984b6SRob Suderman 7869c984b6SRob Suderman auto applyZp = [&](Value val, int64_t zp) -> Value { 7969c984b6SRob Suderman if (zp == 0) 8069c984b6SRob Suderman return val; 815550c821STres Popp auto ety = cast<ShapedType>(val.getType()).getElementType(); 82e0537d1aSTai Ly std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(), 83e0537d1aSTai Ly 1); 84e0537d1aSTai Ly auto zpTy = RankedTensorType::get(shape, ety); 8569c984b6SRob Suderman auto zpAttr = 8669c984b6SRob Suderman DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp)); 8769c984b6SRob Suderman auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr); 8869c984b6SRob Suderman return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val, 8969c984b6SRob Suderman zpVal); 9069c984b6SRob Suderman }; 9169c984b6SRob Suderman 9269c984b6SRob Suderman input = applyZp(input, iZp); 9369c984b6SRob Suderman weight = applyZp(weight, wZp); 9469c984b6SRob Suderman } 9569c984b6SRob Suderman 9611030c7dSAlexander Shaposhnikov ArrayRef<int64_t> padAttr = op.getPad(); 9769c984b6SRob Suderman llvm::SmallVector<int64_t> pad(10, 0); 9811030c7dSAlexander Shaposhnikov for (const auto &it : llvm::enumerate(padAttr)) 9911030c7dSAlexander Shaposhnikov pad[it.index() + 2] = it.value(); 10069c984b6SRob Suderman 10169c984b6SRob Suderman if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) { 10269c984b6SRob Suderman Type inputETy = inputType.getElementType(); 10369c984b6SRob Suderman Attribute zeroAttr = rewriter.getZeroAttr(inputETy); 10469c984b6SRob Suderman 10569c984b6SRob Suderman llvm::SmallVector<int64_t> newShape(inputType.getShape()); 10669c984b6SRob Suderman for (int i = 0, s = pad.size(); i < s; ++i) { 10769c984b6SRob Suderman if (newShape[i / 2] != ShapedType::kDynamic) { 10869c984b6SRob Suderman newShape[i / 2] += pad[i]; 10969c984b6SRob Suderman } 11069c984b6SRob Suderman } 11169c984b6SRob Suderman 1127e622b61SJerry-Ge Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad); 11369c984b6SRob Suderman 11469c984b6SRob Suderman auto padTy = RankedTensorType::get({}, inputETy); 11569c984b6SRob Suderman auto padAttr = DenseElementsAttr::get(padTy, zeroAttr); 11669c984b6SRob Suderman Value padVal = 11769c984b6SRob Suderman rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr); 11869c984b6SRob Suderman inputType = RankedTensorType::get(newShape, inputETy); 11969c984b6SRob Suderman input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input, 12069c984b6SRob Suderman padSizeVal, padVal); 12169c984b6SRob Suderman } 122dfd07082SAaron DeBattista 123dfd07082SAaron DeBattista // Perform an elementwise mul over the reshaped input and weight. 12469c984b6SRob Suderman llvm::SmallVector<int64_t, 2> mulShape{ 12569c984b6SRob Suderman inputType.getDimSize(0), inputType.getDimSize(1), 12669c984b6SRob Suderman inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]}; 127dfd07082SAaron DeBattista auto mulShapeType = RankedTensorType::get( 128dfd07082SAaron DeBattista mulShape, 1295550c821STres Popp dyn_cast<RankedTensorType>(weight.getType()).getElementType()); 130e0537d1aSTai Ly 131e0537d1aSTai Ly if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) { 132e0537d1aSTai Ly return failure(); 133e0537d1aSTai Ly } 134e0537d1aSTai Ly 135*a58e774fSJack Frankland auto shiftElementType = IntegerType::get(rewriter.getContext(), 8); 136*a58e774fSJack Frankland auto shiftType = RankedTensorType::get({1}, shiftElementType); 137*a58e774fSJack Frankland auto shiftZeroAttr = DenseElementsAttr::get( 138*a58e774fSJack Frankland shiftType, rewriter.getIntegerAttr(shiftElementType, 0)); 139*a58e774fSJack Frankland Value constZero = 140*a58e774fSJack Frankland rewriter.create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr); 14169c984b6SRob Suderman Value mulValue = rewriter 14269c984b6SRob Suderman .create<tosa::MulOp>(op.getLoc(), mulShapeType, input, 143*a58e774fSJack Frankland weight, constZero) 144dfd07082SAaron DeBattista .getResult(); 145dfd07082SAaron DeBattista 146dfd07082SAaron DeBattista // Reshape output to [N, H, W, C * M]. 1475550c821STres Popp auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape(); 148dfd07082SAaron DeBattista auto outputShapeType = RankedTensorType::get( 149dfd07082SAaron DeBattista outputShape, 1505550c821STres Popp dyn_cast<RankedTensorType>(input.getType()).getElementType()); 151e0537d1aSTai Ly Value outputValue = rewriter.create<tosa::ReshapeOp>( 1529e1a3441SAlexander Shaposhnikov op.getLoc(), outputShapeType, mulValue, 1539e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(outputShape)); 154dfd07082SAaron DeBattista 155e0537d1aSTai Ly Value bias = op.getBias(); 156e0537d1aSTai Ly if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) { 157e0537d1aSTai Ly return failure(); 158e0537d1aSTai Ly } 159e0537d1aSTai Ly 160dfd07082SAaron DeBattista // Add in the bias. 161dfd07082SAaron DeBattista rewriter 162e0537d1aSTai Ly .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias) 163dfd07082SAaron DeBattista .getResult(); 164dfd07082SAaron DeBattista return success(); 165dfd07082SAaron DeBattista } 166dfd07082SAaron DeBattista }; 167dfd07082SAaron DeBattista 168dfd07082SAaron DeBattista } // namespace 169dfd07082SAaron DeBattista 170dfd07082SAaron DeBattista void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx, 171dfd07082SAaron DeBattista RewritePatternSet &patterns) { 172b4e0507cSTres Popp patterns.add<DepthwiseConv2DIsMul>(ctx); 173dfd07082SAaron DeBattista } 174