//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
// (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Rewrites a tosa.depthwise_conv2d with a 1x1 kernel and unit strides into
/// an elementwise multiply:
///   Reshape(input to rank 5) [-> Cast] [-> zero-point Sub] [-> Pad]
///   -> Mul(input, weight) -> Reshape(to conv output shape) -> Add(bias)
/// A 1x1 depthwise convolution touches each input element exactly once, so
/// the sliding-window reduction degenerates to a per-channel multiply.
struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
  explicit DepthwiseConv2DIsMul(MLIRContext *context)
      : OpRewritePattern(context) {}

  /// Matches only ops where input/weight/result shapes are fully static,
  /// all strides are 1, and the kernel is 1x1. Returns failure() (no match)
  /// otherwise; on success the op is replaced in place via the rewriter.
  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value weight = op.getWeight();
    ShapedType inputType = cast<ShapedType>(input.getType());
    ShapedType weightType = cast<ShapedType>(weight.getType());
    ShapedType resultType = cast<ShapedType>(op.getOutput().getType());

    // All shapes must be static: the rewrite below materializes concrete
    // shape vectors for the reshapes, pad, and mul result types.
    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
          resultType.hasStaticShape())) {
      return failure();
    }

    // Non-unit strides would subsample the input; bail out.
    if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
      return failure();

    // Only works for a 1x1 kernel.
    // Weight layout here is [KH, KW, C, M] (dims 0/1 are the kernel).
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[0] != 1 || weightShape[1] != 1) {
      return failure();
    }

    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
    // The trailing 1 lets the per-channel weight [1, 1, C, M] broadcast the
    // channel-multiplier dimension M during the elementwise mul.
    ArrayRef<int64_t> inputShape = inputType.getShape();
    llvm::SmallVector<int64_t, 2> revisedInputShape{
        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
    inputType = RankedTensorType::get(
        revisedInputShape,
        dyn_cast<RankedTensorType>(input.getType()).getElementType());
    input = rewriter
                .create<tosa::ReshapeOp>(
                    op.getLoc(), inputType, input,
                    rewriter.getDenseI64ArrayAttr(revisedInputShape))
                .getResult();

    // Widen input to the result element type (e.g. quantized i8 input with
    // a wider accumulator result) so the mul/add happen in the result type.
    if (inputType.getElementType() != resultType.getElementType()) {
      inputType = inputType.clone(resultType.getElementType());
      input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
    }

    // Same widening for the weight.
    if (weightType.getElementType() != resultType.getElementType()) {
      weightType = weightType.clone(resultType.getElementType());
      weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
    }

    // For quantized ops, fold the zero points in explicitly by subtracting
    // them from the (already widened) input and weight. This must happen
    // after the casts above so the subtraction cannot overflow the narrow
    // storage type.
    if (auto quantizationInfo = op.getQuantizationInfo()) {
      auto iZp = quantizationInfo->getInputZp();
      auto wZp = quantizationInfo->getWeightZp();

      // Subtracts a splat zero-point constant from `val`; no-op when zp == 0.
      auto applyZp = [&](Value val, int64_t zp) -> Value {
        if (zp == 0)
          return val;
        auto ety = cast<ShapedType>(val.getType()).getElementType();
        // Rank-matched all-ones shape so the splat broadcasts over `val`.
        std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
                                   1);
        auto zpTy = RankedTensorType::get(shape, ety);
        auto zpAttr =
            DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
        auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
        return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
                                            zpVal);
      };

      input = applyZp(input, iZp);
      weight = applyZp(weight, wZp);
    }

    // Build the rank-5 pad list (2 entries per dimension, low/high). The
    // conv pad attribute holds the 4 spatial paddings (H/W), which land at
    // indices 2..5; batch, channel, and multiplier dims stay unpadded.
    ArrayRef<int64_t> padAttr = op.getPad();
    llvm::SmallVector<int64_t> pad(10, 0);
    for (const auto &it : llvm::enumerate(padAttr))
      pad[it.index() + 2] = it.value();

    // Materialize an explicit tosa.pad (zero fill) when any padding exists,
    // growing the static shape accordingly.
    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
      Type inputETy = inputType.getElementType();
      Attribute zeroAttr = rewriter.getZeroAttr(inputETy);

      llvm::SmallVector<int64_t> newShape(inputType.getShape());
      for (int i = 0, s = pad.size(); i < s; ++i) {
        // pad[2*d] and pad[2*d+1] both extend dimension d.
        if (newShape[i / 2] != ShapedType::kDynamic) {
          newShape[i / 2] += pad[i];
        }
      }

      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);

      // Scalar (rank-0) zero constant used as the pad fill value.
      auto padTy = RankedTensorType::get({}, inputETy);
      auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
      Value padVal =
          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
      inputType = RankedTensorType::get(newShape, inputETy);
      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
                                           padSizeVal, padVal);
    }

    // Perform an elementwise mul over the reshaped input and weight.
    // Result shape [N, H, W, C, M]: input's trailing 1 broadcasts against
    // the weight's multiplier dimension (weightShape[3]).
    llvm::SmallVector<int64_t, 2> mulShape{
        inputType.getDimSize(0), inputType.getDimSize(1),
        inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
    auto mulShapeType = RankedTensorType::get(
        mulShape,
        dyn_cast<RankedTensorType>(weight.getType()).getElementType());

    // Pad operand ranks to match (weight is rank 4, input is rank 5) so the
    // broadcasted mul is valid.
    if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
      return failure();
    }

    // tosa.mul takes an i8 shift operand; 0 here means no post-multiply
    // scaling is applied.
    auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
    auto shiftType = RankedTensorType::get({1}, shiftElementType);
    auto shiftZeroAttr = DenseElementsAttr::get(
        shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
    Value constZero =
        rewriter.create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr);
    Value mulValue = rewriter
                         .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
                                              weight, constZero)
                         .getResult();

    // Reshape output to [N, H, W, C * M].
    // Collapse the channel/multiplier pair back to the conv's output layout.
    auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
    auto outputShapeType = RankedTensorType::get(
        outputShape,
        dyn_cast<RankedTensorType>(input.getType()).getElementType());
    Value outputValue = rewriter.create<tosa::ReshapeOp>(
        op.getLoc(), outputShapeType, mulValue,
        rewriter.getDenseI64ArrayAttr(outputShape));

    // Bias is rank 1; equalize ranks so it broadcasts over the rank-4 result.
    Value bias = op.getBias();
    if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
      return failure();
    }

    // Add in the bias.
    rewriter
        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
        .getResult();
    return success();
  }
};

} // namespace

/// Registers the depthwise-conv decomposition pattern with `patterns`.
void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                                RewritePatternSet &patterns) {
  patterns.add<DepthwiseConv2DIsMul>(ctx);
}