//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA Conv2D operation into a series of TOSA ops, specifically:
// (1) Convert a 1x1 convolution to Reshape -> FullyConnected -> Reshape.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

// Map MLIR's dynamic-dimension sentinel to the -1 convention used by
// tosa.reshape's new_shape attribute.
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return ShapedType::isDynamic(dim) ? -1 : dim;
  }));
}

struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
  explicit Conv2DIsFullyConnected(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value weight = op.getWeight();
    ShapedType inputType = cast<ShapedType>(input.getType());
    ShapedType weightType = cast<ShapedType>(weight.getType());
    ShapedType resultType = cast<ShapedType>(op.getType());

    auto numDynamic =
        llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
    if (numDynamic > 1)
      return rewriter.notifyMatchFailure(
          op, "at most one dim in input may be dynamic");
    if (!weightType.hasRank())
      return rewriter.notifyMatchFailure(op, "unranked weight input");

    if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
      return failure();

    // Only works for a 1x1 kernel.
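    // A 1x1 kernel with unit stride applies the same [OC, IC] weight matrix at
    // every spatial position, so the convolution is equivalent to a matmul
    // over the N * IH * IW flattened positions.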
    ArrayRef<int64_t> weightShape = weightType.getShape();
    if (weightShape[1] != 1 || weightShape[2] != 1)
      return failure();

    // Expand the conv pad attribute [top, bottom, left, right] into a full
    // rank-4 pad list (two entries per dimension); N and C stay unpadded.
    llvm::ArrayRef<int64_t> padAttr = op.getPad();
    llvm::SmallVector<int64_t> pad(8, 0);
    for (const auto &it : llvm::enumerate(padAttr))
      pad[it.index() + 2] = it.value();

    // If any padding is required, materialize it as an explicit tosa.pad on
    // the input. Quantized inputs must be padded with the input zero point.
    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
      Type inputETy = inputType.getElementType();
      Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
      if (op.getQuantizationInfo()) {
        auto quantizationInfo = op.getQuantizationInfo();
        int64_t iZp = quantizationInfo->getInputZp();

        if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
          return rewriter.notifyMatchFailure(
              op, "tosa.conv op quantization has zp outside of input range");

        zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
      }

      llvm::SmallVector<int64_t> newShape(inputType.getShape());

      for (int i = 0, s = newShape.size(); i < s; ++i) {
        if (newShape[i] != ShapedType::kDynamic) {
          newShape[i] += pad[i * 2] + pad[i * 2 + 1];
        }
      }

      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);

      auto padTy = RankedTensorType::get({}, inputETy);
      auto padConstAttr = DenseElementsAttr::get(padTy, zeroAttr);
      Value padVal =
          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padConstAttr);
      inputType = RankedTensorType::get(newShape, inputETy);
      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
                                           padSizeVal, padVal);
    }

    // Reshape input to [N, IH, IW, IC] -> [N * IH * IW, IC].
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t combined = ShapedType::kDynamic;
    if (numDynamic == 0)
      combined = inputShape[0] * inputShape[1] * inputShape[2];
    llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
    auto revisedInputShapeType =
        RankedTensorType::get(revisedInputShape, inputType.getElementType());
    auto reshapedInput = rewriter
                             .create<tosa::ReshapeOp>(
                                 op.getLoc(), revisedInputShapeType, input,
                                 rewriter.getDenseI64ArrayAttr(
                                     convertFromMlirShape(revisedInputShape)))
                             .getResult();

    // Reshape kernel to [OC, KH, KW, IC] -> [OC, IC].
    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
                                                     weightShape[3]};
    auto revisedWeightShapeType = RankedTensorType::get(
        revisedWeightShape,
        dyn_cast<RankedTensorType>(weight.getType()).getElementType());
    auto reshapedWeight = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedWeightShapeType, weight,
                                  rewriter.getDenseI64ArrayAttr(
                                      convertFromMlirShape(revisedWeightShape)))
                              .getResult();

    // Perform a fully connected op over the reshaped input and weight.
    llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined,
                                                      weightShape[0]};
    auto fullyConnectedShapeType =
        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());

    Value fullyConnectedValue;
    if (op.getQuantizationInfo()) {
      fullyConnectedValue =
          rewriter
              .create<tosa::FullyConnectedOp>(
                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
                  reshapedWeight, op.getBias(), *op.getQuantizationInfo())
              .getResult();
    } else {
      fullyConnectedValue = rewriter
                                .create<tosa::FullyConnectedOp>(
                                    op.getLoc(), fullyConnectedShapeType,
                                    reshapedInput, reshapedWeight, op.getBias())
                                .getResult();
    }

    // Reshape output to [N, IH, IW, OC].
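    // Worked example: a [1, 8, 8, 16] input and a [32, 1, 1, 16] weight become
    // a [64, 16] input and a [32, 16] weight; the fully connected op produces
    // [64, 32], which is reshaped back to [1, 8, 8, 32] below.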
    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
                                              inputShape[2], weightShape[0]};
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultType, fullyConnectedValue,
        rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
                                             RewritePatternSet &patterns) {
  patterns.add<Conv2DIsFullyConnected>(ctx);
}
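
// Usage sketch: one way a pass might apply this decomposition, assuming the
// greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h. The pass
// class name below is hypothetical; only populateTosaDecomposeConv2D above is
// part of this file's API.
//
//   void DecomposeConv2DPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     mlir::tosa::populateTosaDecomposeConv2D(&getContext(), patterns);
//     // Apply the pattern until fixpoint; failure here only means the
//     // driver did not converge, not that the pattern never matched.
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }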