13bcaf2ebSGeorgios Pinitas //===- TosaDecomposeConv2D.cpp --------------------------------------------===// 2dfd07082SAaron DeBattista // 3dfd07082SAaron DeBattista // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4dfd07082SAaron DeBattista // See https://llvm.org/LICENSE.txt for license information. 5dfd07082SAaron DeBattista // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6dfd07082SAaron DeBattista // 7dfd07082SAaron DeBattista //===----------------------------------------------------------------------===// 8dfd07082SAaron DeBattista // 9dfd07082SAaron DeBattista // Decompose TOSA Conv2D operation to a series of TOSA Ops specifically 10dfd07082SAaron DeBattista // (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape 11dfd07082SAaron DeBattista // 12dfd07082SAaron DeBattista //===----------------------------------------------------------------------===// 13dfd07082SAaron DeBattista 14dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/IR/TosaOps.h" 15dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/Transforms/Passes.h" 1669c984b6SRob Suderman #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 17dfd07082SAaron DeBattista 18dfd07082SAaron DeBattista using namespace mlir; 19dfd07082SAaron DeBattista using namespace mlir::tosa; 20dfd07082SAaron DeBattista 21dfd07082SAaron DeBattista namespace { 22dfd07082SAaron DeBattista 234ecfdf8aSMehdi Amini SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) { 24fb4cedccSAliia Khasanova return to_vector(llvm::map_range(shape, [](int64_t dim) { 25fb4cedccSAliia Khasanova return ShapedType::isDynamic(dim) ? -1 : dim; 26fb4cedccSAliia Khasanova })); 27fb4cedccSAliia Khasanova } 28fb4cedccSAliia Khasanova 29dfd07082SAaron DeBattista struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> { 30dfd07082SAaron DeBattista explicit Conv2DIsFullyConnected(MLIRContext *context) 31dfd07082SAaron DeBattista : OpRewritePattern(context) {} 32dfd07082SAaron DeBattista 33dfd07082SAaron DeBattista LogicalResult matchAndRewrite(tosa::Conv2DOp op, 34dfd07082SAaron DeBattista PatternRewriter &rewriter) const override { 3513448db0SJacques Pienaar Value input = op.getInput(); 3613448db0SJacques Pienaar Value weight = op.getWeight(); 375550c821STres Popp ShapedType inputType = cast<ShapedType>(input.getType()); 385550c821STres Popp ShapedType weightType = cast<ShapedType>(weight.getType()); 395550c821STres Popp ShapedType resultType = cast<ShapedType>(op.getType()); 40dfd07082SAaron DeBattista 41380a1b20SKazu Hirata auto numDynamic = 42380a1b20SKazu Hirata llvm::count_if(inputType.getShape(), ShapedType::isDynamic); 43e08a991fSJacques Pienaar if (numDynamic > 1) 44e08a991fSJacques Pienaar return rewriter.notifyMatchFailure( 45e08a991fSJacques Pienaar op, "at most one dim in input may be dynamic"); 46e08a991fSJacques Pienaar if (!weightType.hasRank()) 47e08a991fSJacques Pienaar return rewriter.notifyMatchFailure(op, "unranked weight input"); 48dfd07082SAaron DeBattista 4911030c7dSAlexander Shaposhnikov if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; })) 50dfd07082SAaron DeBattista return failure(); 51dfd07082SAaron DeBattista 52dfd07082SAaron DeBattista // Only works for a 1x1 kernel. 53dfd07082SAaron DeBattista ArrayRef<int64_t> weightShape = weightType.getShape(); 54e08a991fSJacques Pienaar if (weightShape[1] != 1 || weightShape[2] != 1) 55dfd07082SAaron DeBattista return failure(); 56dfd07082SAaron DeBattista 5711030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> padAttr = op.getPad(); 5869c984b6SRob Suderman llvm::SmallVector<int64_t> pad(8, 0); 5911030c7dSAlexander Shaposhnikov for (const auto &it : llvm::enumerate(padAttr)) 6011030c7dSAlexander Shaposhnikov pad[it.index() + 2] = it.value(); 6169c984b6SRob Suderman 6269c984b6SRob Suderman if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) { 6369c984b6SRob Suderman Type inputETy = inputType.getElementType(); 6469c984b6SRob Suderman Attribute zeroAttr = rewriter.getZeroAttr(inputETy); 6569c984b6SRob Suderman if (op.getQuantizationInfo()) { 6669c984b6SRob Suderman auto quantizationInfo = op.getQuantizationInfo(); 6769c984b6SRob Suderman int64_t iZp = quantizationInfo->getInputZp(); 6869c984b6SRob Suderman 695550c821STres Popp if (!validIntegerRange(cast<IntegerType>(inputETy), iZp)) 7069c984b6SRob Suderman return rewriter.notifyMatchFailure( 7169c984b6SRob Suderman op, "tosa.conv op quantization has zp outside of input range"); 7269c984b6SRob Suderman 7369c984b6SRob Suderman zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); 7469c984b6SRob Suderman } 7569c984b6SRob Suderman 7669c984b6SRob Suderman llvm::SmallVector<int64_t> newShape(inputType.getShape()); 7769c984b6SRob Suderman 7869c984b6SRob Suderman for (int i = 0, s = newShape.size(); i < s; ++i) { 7969c984b6SRob Suderman if (newShape[i] != ShapedType::kDynamic) { 8069c984b6SRob Suderman newShape[i] += pad[i * 2] + pad[i * 2 + 1]; 8169c984b6SRob Suderman } 8269c984b6SRob Suderman } 8369c984b6SRob Suderman 84*7e622b61SJerry-Ge Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad); 8569c984b6SRob Suderman 8669c984b6SRob Suderman auto padTy = RankedTensorType::get({}, inputETy); 8769c984b6SRob Suderman auto padAttr = DenseElementsAttr::get(padTy, zeroAttr); 8869c984b6SRob Suderman Value padVal = 8969c984b6SRob Suderman rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr); 9069c984b6SRob Suderman inputType = RankedTensorType::get(newShape, inputETy); 9169c984b6SRob Suderman input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input, 9269c984b6SRob Suderman padSizeVal, padVal); 9369c984b6SRob Suderman } 9469c984b6SRob Suderman 95dfd07082SAaron DeBattista // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. 96dfd07082SAaron DeBattista ArrayRef<int64_t> inputShape = inputType.getShape(); 97399638f9SAliia Khasanova int64_t combined = ShapedType::kDynamic; 98fb4cedccSAliia Khasanova if (numDynamic == 0) 99fb4cedccSAliia Khasanova combined = inputShape[0] * inputShape[1] * inputShape[2]; 100e08a991fSJacques Pienaar llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]}; 101e08a991fSJacques Pienaar auto revisedInputShapeType = 102e08a991fSJacques Pienaar RankedTensorType::get(revisedInputShape, inputType.getElementType()); 103dfd07082SAaron DeBattista auto reshapedInput = rewriter 104dfd07082SAaron DeBattista .create<tosa::ReshapeOp>( 105dfd07082SAaron DeBattista op.getLoc(), revisedInputShapeType, input, 1069e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr( 1074ecfdf8aSMehdi Amini convertFromMlirShape(revisedInputShape))) 108dfd07082SAaron DeBattista .getResult(); 109dfd07082SAaron DeBattista 110dfd07082SAaron DeBattista // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. 111dfd07082SAaron DeBattista llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0], 112dfd07082SAaron DeBattista weightShape[3]}; 113dfd07082SAaron DeBattista auto revisedWeightShapeType = RankedTensorType::get( 114dfd07082SAaron DeBattista revisedWeightShape, 1155550c821STres Popp dyn_cast<RankedTensorType>(weight.getType()).getElementType()); 116dfd07082SAaron DeBattista auto reshapedWeight = rewriter 117dfd07082SAaron DeBattista .create<tosa::ReshapeOp>( 118dfd07082SAaron DeBattista op.getLoc(), revisedWeightShapeType, weight, 1199e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr( 1204ecfdf8aSMehdi Amini convertFromMlirShape(revisedWeightShape))) 121dfd07082SAaron DeBattista .getResult(); 122dfd07082SAaron DeBattista 123dfd07082SAaron DeBattista // Perform a fully connected network over the reshaped input and weight. 124e08a991fSJacques Pienaar llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]}; 125e08a991fSJacques Pienaar auto fullyConnectedShapeType = 126e08a991fSJacques Pienaar RankedTensorType::get(fullyConnectedShape, resultType.getElementType()); 127dfd07082SAaron DeBattista 128dfd07082SAaron DeBattista Value fullyConnectedValue; 12913448db0SJacques Pienaar if (op.getQuantizationInfo()) { 130dfd07082SAaron DeBattista fullyConnectedValue = 131dfd07082SAaron DeBattista rewriter 132dfd07082SAaron DeBattista .create<tosa::FullyConnectedOp>( 133dfd07082SAaron DeBattista op.getLoc(), fullyConnectedShapeType, reshapedInput, 13413448db0SJacques Pienaar reshapedWeight, op.getBias(), *op.getQuantizationInfo()) 135dfd07082SAaron DeBattista .getResult(); 136dfd07082SAaron DeBattista } else { 137dfd07082SAaron DeBattista fullyConnectedValue = rewriter 138dfd07082SAaron DeBattista .create<tosa::FullyConnectedOp>( 139dfd07082SAaron DeBattista op.getLoc(), fullyConnectedShapeType, 14013448db0SJacques Pienaar reshapedInput, reshapedWeight, op.getBias()) 141dfd07082SAaron DeBattista .getResult(); 142dfd07082SAaron DeBattista } 143dfd07082SAaron DeBattista 144dfd07082SAaron DeBattista // Reshape output to [N, IH, IW, OC]. 145dfd07082SAaron DeBattista llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1], 146dfd07082SAaron DeBattista inputShape[2], weightShape[0]}; 147dfd07082SAaron DeBattista rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( 148dfd07082SAaron DeBattista op, resultType, fullyConnectedValue, 1499e1a3441SAlexander Shaposhnikov rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape))); 150dfd07082SAaron DeBattista return success(); 151dfd07082SAaron DeBattista } 152dfd07082SAaron DeBattista }; 153dfd07082SAaron DeBattista 154dfd07082SAaron DeBattista } // namespace 155dfd07082SAaron DeBattista 156dfd07082SAaron DeBattista void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx, 157dfd07082SAaron DeBattista RewritePatternSet &patterns) { 158b4e0507cSTres Popp patterns.add<Conv2DIsFullyConnected>(ctx); 159dfd07082SAaron DeBattista } 160