1 //===- TosaTestPasses.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Test passes to exercise TOSA helper functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 16 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 23 #define PASS_NAME "tosa-test-quant-utils" 24 25 using namespace mlir; 26 using namespace mlir::tosa; 27 28 // This transformation converts quantized uint8 to quantized int8. The 29 // construction of the new type invokes buildQTypeFromMinMax. Extracted from 30 // TOSA legalization infrastructure. 31 struct ConvertTosaNegateOp : public RewritePattern { 32 explicit ConvertTosaNegateOp(MLIRContext *context) 33 : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {} 34 LogicalResult matchAndRewrite(Operation *op, 35 PatternRewriter &rewriter) const override; 36 }; 37 38 LogicalResult 39 ConvertTosaNegateOp::matchAndRewrite(Operation *op, 40 PatternRewriter &rewriter) const { 41 42 auto tosaNegateOp = cast<tosa::NegateOp>(op); 43 44 auto inputType = 45 dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType()); 46 // skip if input is not ranked tensor type 47 if (!inputType) 48 return failure(); 49 50 // skip if it's not ranked tensor type. 51 auto outputType = 52 dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType()); 53 if (!outputType) 54 return failure(); 55 56 // skip if output is not per-tensor quantized type. 57 auto outputElementType = 58 dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType()); 59 if (!outputElementType) 60 return failure(); 61 62 // skip if output is not uint8. 63 if (outputElementType.isSigned() || 64 outputElementType.getStorageTypeIntegralWidth() != 8) 65 return failure(); 66 67 double typeRangeMin = double(outputElementType.getStorageTypeMin() - 68 outputElementType.getZeroPoint()) * 69 outputElementType.getScale(); 70 double typeRangeMax = double(outputElementType.getStorageTypeMax() - 71 outputElementType.getZeroPoint()) * 72 outputElementType.getScale(); 73 bool narrowRange = outputElementType.getStorageTypeMin() == 1; 74 75 auto dstQConstType = RankedTensorType::get( 76 outputType.getShape(), 77 buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(), 78 rewriter.getF64FloatAttr(typeRangeMin), 79 rewriter.getF64FloatAttr(typeRangeMax), 80 rewriter.getI32IntegerAttr( 81 outputElementType.getStorageTypeIntegralWidth()), 82 0, true /* signed */, 83 rewriter.getBoolAttr(narrowRange))); 84 85 ElementsAttr inputElems; 86 if (!matchPattern(tosaNegateOp.getInput1(), m_Constant(&inputElems))) 87 return failure(); 88 89 auto newConstOp = 90 rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems); 91 auto newNegateOp = rewriter.create<tosa::NegateOp>( 92 op->getLoc(), dstQConstType, newConstOp.getResult()); 93 94 rewriter.replaceOp(op, {newNegateOp.getResult()}); 95 return success(); 96 } 97 98 // This transformation modifies the quantized output of a test conv2d input and 99 // appends a TOSA rescale after it. The rescale op requires the invocation of 100 // computeMultiplierAndShift. From TOSA legalization infrastructure. 101 struct ConvertTosaConv2DOp : public RewritePattern { 102 explicit ConvertTosaConv2DOp(MLIRContext *context) 103 : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {} 104 LogicalResult matchAndRewrite(Operation *op, 105 PatternRewriter &rewriter) const override; 106 }; 107 108 LogicalResult 109 ConvertTosaConv2DOp::matchAndRewrite(Operation *op, 110 PatternRewriter &rewriter) const { 111 112 auto tosaConv2DOp = cast<tosa::Conv2DOp>(op); 113 114 auto inputType = 115 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getInput().getType()); 116 117 // skip if input is not ranked tensor type 118 if (!inputType) 119 return failure(); 120 121 auto weightType = 122 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getWeight().getType()); 123 124 // skip if wt is not ranked tensor type 125 if (!weightType) 126 return failure(); 127 128 // skip if it's not ranked tensor type. 129 auto outputType = 130 dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getResult().getType()); 131 if (!outputType) 132 return failure(); 133 134 auto inputQType = 135 dyn_cast<mlir::quant::UniformQuantizedType>(inputType.getElementType()); 136 auto weightQType = 137 dyn_cast<mlir::quant::UniformQuantizedType>(weightType.getElementType()); 138 auto outputQType = 139 dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType()); 140 141 // Works on quantized type only. 142 if (!(inputQType && weightQType && outputQType)) 143 return failure(); 144 145 auto newTosaConv2DOpType = 146 RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32)); 147 148 auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>( 149 op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), 150 tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(), 151 tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(), 152 tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr()); 153 154 // Create rescale to quantized type 155 double inputScale = inputQType.getScale(); 156 double weightScale = weightQType.getScale(); 157 double outputScale = outputQType.getScale(); 158 int64_t outputZp = outputQType.getZeroPoint(); 159 160 double opTensorScale = (inputScale * weightScale) / outputScale; 161 162 int32_t multiplier; 163 int32_t shift; 164 165 // Obtain the quantized scale = multiplier and shift. 166 computeMultiplierAndShift(opTensorScale, multiplier, shift, 32); 167 168 auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>( 169 op->getLoc(), outputType, newTosaConv2DOp.getResult(), 170 rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp), 171 rewriter.getDenseI32ArrayAttr({multiplier}), 172 rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}), 173 rewriter.getBoolAttr(true), rewriter.getBoolAttr(true), 174 rewriter.getBoolAttr(false)); 175 176 rewriter.replaceOp(op, {newTosaRescaleOp.getResult()}); 177 return success(); 178 } 179 180 namespace { 181 182 struct TosaTestQuantUtilAPI 183 : public PassWrapper<TosaTestQuantUtilAPI, OperationPass<func::FuncOp>> { 184 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI) 185 186 StringRef getArgument() const final { return PASS_NAME; } 187 StringRef getDescription() const final { 188 return "TOSA Test: Exercise the APIs in QuantUtils.cpp."; 189 } 190 void runOnOperation() override; 191 }; 192 193 void TosaTestQuantUtilAPI::runOnOperation() { 194 auto *ctx = &getContext(); 195 RewritePatternSet patterns(ctx); 196 auto func = getOperation(); 197 198 patterns.add<ConvertTosaNegateOp>(ctx); 199 patterns.add<ConvertTosaConv2DOp>(ctx); 200 (void)applyPatternsGreedily(func, std::move(patterns)); 201 } 202 203 } // namespace 204 205 namespace mlir { 206 void registerTosaTestQuantUtilAPIPass() { 207 PassRegistration<TosaTestQuantUtilAPI>(); 208 } 209 } // namespace mlir 210