xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (revision 7e622b61320543b3706711609f1f32fd9ea3788d)
13bcaf2ebSGeorgios Pinitas //===- TosaDecomposeConv2D.cpp --------------------------------------------===//
2dfd07082SAaron DeBattista //
3dfd07082SAaron DeBattista // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4dfd07082SAaron DeBattista // See https://llvm.org/LICENSE.txt for license information.
5dfd07082SAaron DeBattista // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6dfd07082SAaron DeBattista //
7dfd07082SAaron DeBattista //===----------------------------------------------------------------------===//
8dfd07082SAaron DeBattista //
9dfd07082SAaron DeBattista // Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
10dfd07082SAaron DeBattista // (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
11dfd07082SAaron DeBattista //
12dfd07082SAaron DeBattista //===----------------------------------------------------------------------===//
13dfd07082SAaron DeBattista 
14dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15dfd07082SAaron DeBattista #include "mlir/Dialect/Tosa/Transforms/Passes.h"
1669c984b6SRob Suderman #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
17dfd07082SAaron DeBattista 
18dfd07082SAaron DeBattista using namespace mlir;
19dfd07082SAaron DeBattista using namespace mlir::tosa;
20dfd07082SAaron DeBattista 
21dfd07082SAaron DeBattista namespace {
22dfd07082SAaron DeBattista 
234ecfdf8aSMehdi Amini SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
24fb4cedccSAliia Khasanova   return to_vector(llvm::map_range(shape, [](int64_t dim) {
25fb4cedccSAliia Khasanova     return ShapedType::isDynamic(dim) ? -1 : dim;
26fb4cedccSAliia Khasanova   }));
27fb4cedccSAliia Khasanova }
28fb4cedccSAliia Khasanova 
29dfd07082SAaron DeBattista struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
30dfd07082SAaron DeBattista   explicit Conv2DIsFullyConnected(MLIRContext *context)
31dfd07082SAaron DeBattista       : OpRewritePattern(context) {}
32dfd07082SAaron DeBattista 
33dfd07082SAaron DeBattista   LogicalResult matchAndRewrite(tosa::Conv2DOp op,
34dfd07082SAaron DeBattista                                 PatternRewriter &rewriter) const override {
3513448db0SJacques Pienaar     Value input = op.getInput();
3613448db0SJacques Pienaar     Value weight = op.getWeight();
375550c821STres Popp     ShapedType inputType = cast<ShapedType>(input.getType());
385550c821STres Popp     ShapedType weightType = cast<ShapedType>(weight.getType());
395550c821STres Popp     ShapedType resultType = cast<ShapedType>(op.getType());
40dfd07082SAaron DeBattista 
41380a1b20SKazu Hirata     auto numDynamic =
42380a1b20SKazu Hirata         llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
43e08a991fSJacques Pienaar     if (numDynamic > 1)
44e08a991fSJacques Pienaar       return rewriter.notifyMatchFailure(
45e08a991fSJacques Pienaar           op, "at most one dim in input may be dynamic");
46e08a991fSJacques Pienaar     if (!weightType.hasRank())
47e08a991fSJacques Pienaar       return rewriter.notifyMatchFailure(op, "unranked weight input");
48dfd07082SAaron DeBattista 
4911030c7dSAlexander Shaposhnikov     if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
50dfd07082SAaron DeBattista       return failure();
51dfd07082SAaron DeBattista 
52dfd07082SAaron DeBattista     // Only works for a 1x1 kernel.
53dfd07082SAaron DeBattista     ArrayRef<int64_t> weightShape = weightType.getShape();
54e08a991fSJacques Pienaar     if (weightShape[1] != 1 || weightShape[2] != 1)
55dfd07082SAaron DeBattista       return failure();
56dfd07082SAaron DeBattista 
5711030c7dSAlexander Shaposhnikov     llvm::ArrayRef<int64_t> padAttr = op.getPad();
5869c984b6SRob Suderman     llvm::SmallVector<int64_t> pad(8, 0);
5911030c7dSAlexander Shaposhnikov     for (const auto &it : llvm::enumerate(padAttr))
6011030c7dSAlexander Shaposhnikov       pad[it.index() + 2] = it.value();
6169c984b6SRob Suderman 
6269c984b6SRob Suderman     if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
6369c984b6SRob Suderman       Type inputETy = inputType.getElementType();
6469c984b6SRob Suderman       Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
6569c984b6SRob Suderman       if (op.getQuantizationInfo()) {
6669c984b6SRob Suderman         auto quantizationInfo = op.getQuantizationInfo();
6769c984b6SRob Suderman         int64_t iZp = quantizationInfo->getInputZp();
6869c984b6SRob Suderman 
695550c821STres Popp         if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
7069c984b6SRob Suderman           return rewriter.notifyMatchFailure(
7169c984b6SRob Suderman               op, "tosa.conv op quantization has zp outside of input range");
7269c984b6SRob Suderman 
7369c984b6SRob Suderman         zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
7469c984b6SRob Suderman       }
7569c984b6SRob Suderman 
7669c984b6SRob Suderman       llvm::SmallVector<int64_t> newShape(inputType.getShape());
7769c984b6SRob Suderman 
7869c984b6SRob Suderman       for (int i = 0, s = newShape.size(); i < s; ++i) {
7969c984b6SRob Suderman         if (newShape[i] != ShapedType::kDynamic) {
8069c984b6SRob Suderman           newShape[i] += pad[i * 2] + pad[i * 2 + 1];
8169c984b6SRob Suderman         }
8269c984b6SRob Suderman       }
8369c984b6SRob Suderman 
84*7e622b61SJerry-Ge       Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
8569c984b6SRob Suderman 
8669c984b6SRob Suderman       auto padTy = RankedTensorType::get({}, inputETy);
8769c984b6SRob Suderman       auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
8869c984b6SRob Suderman       Value padVal =
8969c984b6SRob Suderman           rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
9069c984b6SRob Suderman       inputType = RankedTensorType::get(newShape, inputETy);
9169c984b6SRob Suderman       input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
9269c984b6SRob Suderman                                            padSizeVal, padVal);
9369c984b6SRob Suderman     }
9469c984b6SRob Suderman 
95dfd07082SAaron DeBattista     // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
96dfd07082SAaron DeBattista     ArrayRef<int64_t> inputShape = inputType.getShape();
97399638f9SAliia Khasanova     int64_t combined = ShapedType::kDynamic;
98fb4cedccSAliia Khasanova     if (numDynamic == 0)
99fb4cedccSAliia Khasanova       combined = inputShape[0] * inputShape[1] * inputShape[2];
100e08a991fSJacques Pienaar     llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
101e08a991fSJacques Pienaar     auto revisedInputShapeType =
102e08a991fSJacques Pienaar         RankedTensorType::get(revisedInputShape, inputType.getElementType());
103dfd07082SAaron DeBattista     auto reshapedInput = rewriter
104dfd07082SAaron DeBattista                              .create<tosa::ReshapeOp>(
105dfd07082SAaron DeBattista                                  op.getLoc(), revisedInputShapeType, input,
1069e1a3441SAlexander Shaposhnikov                                  rewriter.getDenseI64ArrayAttr(
1074ecfdf8aSMehdi Amini                                      convertFromMlirShape(revisedInputShape)))
108dfd07082SAaron DeBattista                              .getResult();
109dfd07082SAaron DeBattista 
110dfd07082SAaron DeBattista     // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
111dfd07082SAaron DeBattista     llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
112dfd07082SAaron DeBattista                                                      weightShape[3]};
113dfd07082SAaron DeBattista     auto revisedWeightShapeType = RankedTensorType::get(
114dfd07082SAaron DeBattista         revisedWeightShape,
1155550c821STres Popp         dyn_cast<RankedTensorType>(weight.getType()).getElementType());
116dfd07082SAaron DeBattista     auto reshapedWeight = rewriter
117dfd07082SAaron DeBattista                               .create<tosa::ReshapeOp>(
118dfd07082SAaron DeBattista                                   op.getLoc(), revisedWeightShapeType, weight,
1199e1a3441SAlexander Shaposhnikov                                   rewriter.getDenseI64ArrayAttr(
1204ecfdf8aSMehdi Amini                                       convertFromMlirShape(revisedWeightShape)))
121dfd07082SAaron DeBattista                               .getResult();
122dfd07082SAaron DeBattista 
123dfd07082SAaron DeBattista     // Perform a fully connected network over the reshaped input and weight.
124e08a991fSJacques Pienaar     llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
125e08a991fSJacques Pienaar     auto fullyConnectedShapeType =
126e08a991fSJacques Pienaar         RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
127dfd07082SAaron DeBattista 
128dfd07082SAaron DeBattista     Value fullyConnectedValue;
12913448db0SJacques Pienaar     if (op.getQuantizationInfo()) {
130dfd07082SAaron DeBattista       fullyConnectedValue =
131dfd07082SAaron DeBattista           rewriter
132dfd07082SAaron DeBattista               .create<tosa::FullyConnectedOp>(
133dfd07082SAaron DeBattista                   op.getLoc(), fullyConnectedShapeType, reshapedInput,
13413448db0SJacques Pienaar                   reshapedWeight, op.getBias(), *op.getQuantizationInfo())
135dfd07082SAaron DeBattista               .getResult();
136dfd07082SAaron DeBattista     } else {
137dfd07082SAaron DeBattista       fullyConnectedValue = rewriter
138dfd07082SAaron DeBattista                                 .create<tosa::FullyConnectedOp>(
139dfd07082SAaron DeBattista                                     op.getLoc(), fullyConnectedShapeType,
14013448db0SJacques Pienaar                                     reshapedInput, reshapedWeight, op.getBias())
141dfd07082SAaron DeBattista                                 .getResult();
142dfd07082SAaron DeBattista     }
143dfd07082SAaron DeBattista 
144dfd07082SAaron DeBattista     // Reshape output to [N, IH, IW, OC].
145dfd07082SAaron DeBattista     llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
146dfd07082SAaron DeBattista                                               inputShape[2], weightShape[0]};
147dfd07082SAaron DeBattista     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
148dfd07082SAaron DeBattista         op, resultType, fullyConnectedValue,
1499e1a3441SAlexander Shaposhnikov         rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
150dfd07082SAaron DeBattista     return success();
151dfd07082SAaron DeBattista   }
152dfd07082SAaron DeBattista };
153dfd07082SAaron DeBattista 
154dfd07082SAaron DeBattista } // namespace
155dfd07082SAaron DeBattista 
156dfd07082SAaron DeBattista void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
157dfd07082SAaron DeBattista                                              RewritePatternSet &patterns) {
158b4e0507cSTres Popp   patterns.add<Conv2DIsFullyConnected>(ctx);
159dfd07082SAaron DeBattista }
160