xref: /llvm-project/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (revision 360a03c980e3e96ac53746b118a04305a28a5310)
1 //===- TosaTestPasses.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Test passes to exercise TOSA helper functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 #define PASS_NAME "tosa-test-quant-utils"
24 
25 using namespace mlir;
26 using namespace mlir::tosa;
27 
28 // This transformation converts quantized uint8 to quantized int8. The
29 // construction of the new type invokes buildQTypeFromMinMax. Extracted from
30 // TOSA legalization infrastructure.
31 struct ConvertTosaNegateOp : public RewritePattern {
32   explicit ConvertTosaNegateOp(MLIRContext *context)
33       : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {}
34   LogicalResult matchAndRewrite(Operation *op,
35                                 PatternRewriter &rewriter) const override;
36 };
37 
38 LogicalResult
39 ConvertTosaNegateOp::matchAndRewrite(Operation *op,
40                                      PatternRewriter &rewriter) const {
41 
42   auto tosaNegateOp = cast<tosa::NegateOp>(op);
43 
44   auto inputType =
45       dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType());
46   // skip if input is not ranked tensor type
47   if (!inputType)
48     return failure();
49 
50   // skip if it's not ranked tensor type.
51   auto outputType =
52       dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType());
53   if (!outputType)
54     return failure();
55 
56   // skip if output is not per-tensor quantized type.
57   auto outputElementType =
58       dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
59   if (!outputElementType)
60     return failure();
61 
62   // skip if output is not uint8.
63   if (outputElementType.isSigned() ||
64       outputElementType.getStorageTypeIntegralWidth() != 8)
65     return failure();
66 
67   double typeRangeMin = double(outputElementType.getStorageTypeMin() -
68                                outputElementType.getZeroPoint()) *
69                         outputElementType.getScale();
70   double typeRangeMax = double(outputElementType.getStorageTypeMax() -
71                                outputElementType.getZeroPoint()) *
72                         outputElementType.getScale();
73   bool narrowRange = outputElementType.getStorageTypeMin() == 1;
74 
75   auto dstQConstType = RankedTensorType::get(
76       outputType.getShape(),
77       buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(),
78                            rewriter.getF64FloatAttr(typeRangeMin),
79                            rewriter.getF64FloatAttr(typeRangeMax),
80                            rewriter.getI32IntegerAttr(
81                                outputElementType.getStorageTypeIntegralWidth()),
82                            0, true /* signed */,
83                            rewriter.getBoolAttr(narrowRange)));
84 
85   ElementsAttr inputElems;
86   if (!matchPattern(tosaNegateOp.getInput1(), m_Constant(&inputElems)))
87     return failure();
88 
89   auto newConstOp =
90       rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
91   auto newNegateOp = rewriter.create<tosa::NegateOp>(
92       op->getLoc(), dstQConstType, newConstOp.getResult());
93 
94   rewriter.replaceOp(op, {newNegateOp.getResult()});
95   return success();
96 }
97 
98 // This transformation modifies the quantized output of a test conv2d input and
99 // appends a TOSA rescale after it. The rescale op requires the invocation of
100 // computeMultiplierAndShift. From TOSA legalization infrastructure.
101 struct ConvertTosaConv2DOp : public RewritePattern {
102   explicit ConvertTosaConv2DOp(MLIRContext *context)
103       : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {}
104   LogicalResult matchAndRewrite(Operation *op,
105                                 PatternRewriter &rewriter) const override;
106 };
107 
108 LogicalResult
109 ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
110                                      PatternRewriter &rewriter) const {
111 
112   auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
113 
114   auto inputType =
115       dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getInput().getType());
116 
117   // skip if input is not ranked tensor type
118   if (!inputType)
119     return failure();
120 
121   auto weightType =
122       dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getWeight().getType());
123 
124   // skip if wt is not ranked tensor type
125   if (!weightType)
126     return failure();
127 
128   // skip if it's not ranked tensor type.
129   auto outputType =
130       dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getResult().getType());
131   if (!outputType)
132     return failure();
133 
134   auto inputQType =
135       dyn_cast<mlir::quant::UniformQuantizedType>(inputType.getElementType());
136   auto weightQType =
137       dyn_cast<mlir::quant::UniformQuantizedType>(weightType.getElementType());
138   auto outputQType =
139       dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
140 
141   // Works on quantized type only.
142   if (!(inputQType && weightQType && outputQType))
143     return failure();
144 
145   auto newTosaConv2DOpType =
146       RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32));
147 
148   auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
149       op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(),
150       tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(),
151       tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(),
152       tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr());
153 
154   // Create rescale to quantized type
155   double inputScale = inputQType.getScale();
156   double weightScale = weightQType.getScale();
157   double outputScale = outputQType.getScale();
158   int64_t outputZp = outputQType.getZeroPoint();
159 
160   double opTensorScale = (inputScale * weightScale) / outputScale;
161 
162   int32_t multiplier;
163   int32_t shift;
164 
165   // Obtain the quantized scale = multiplier and shift.
166   computeMultiplierAndShift(opTensorScale, multiplier, shift, 32);
167 
168   auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
169       op->getLoc(), outputType, newTosaConv2DOp.getResult(),
170       rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
171       rewriter.getDenseI32ArrayAttr({multiplier}),
172       rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
173       rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
174       rewriter.getBoolAttr(false));
175 
176   rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
177   return success();
178 }
179 
180 namespace {
181 
182 struct TosaTestQuantUtilAPI
183     : public PassWrapper<TosaTestQuantUtilAPI, OperationPass<func::FuncOp>> {
184   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI)
185 
186   StringRef getArgument() const final { return PASS_NAME; }
187   StringRef getDescription() const final {
188     return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
189   }
190   void runOnOperation() override;
191 };
192 
193 void TosaTestQuantUtilAPI::runOnOperation() {
194   auto *ctx = &getContext();
195   RewritePatternSet patterns(ctx);
196   auto func = getOperation();
197 
198   patterns.add<ConvertTosaNegateOp>(ctx);
199   patterns.add<ConvertTosaConv2DOp>(ctx);
200   (void)applyPatternsGreedily(func, std::move(patterns));
201 }
202 
203 } // namespace
204 
205 namespace mlir {
206 void registerTosaTestQuantUtilAPIPass() {
207   PassRegistration<TosaTestQuantUtilAPI>();
208 }
209 } // namespace mlir
210