//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "QuantDialectBytecode.h"
#include "TypeDetail.h"

#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"


namespace mlir {
namespace quant {

namespace {

// Verify the integrity of per-axis quantization information, if present.
//
// - quantizedType
//   Any quantized type. Any quantized type with no per-axis quantization is
//   ignored.
//
// - containerType
//   Original input or result type of the operation using the provided
//   quantized type. Used to ensure that the quantized type appears within a
//   tensor and that the tensor is compatible with per-axis quantization
//   information.
//
LogicalResult verifyPerAxisQuantization(Operation *op,
                                        QuantizedType quantizedType,
                                        Type containerType) {
  auto quantizedPerAxisType =
      dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
  if (!quantizedPerAxisType)
    return success();

  auto tensorType = dyn_cast<TensorType>(containerType);
  if (!tensorType)
    return op->emitError("scalar types may not use per-axis quantization");

  if (!tensorType.hasRank())
    return success();

  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
  if (quantizedDimension >= tensorType.getRank())
    return op->emitError("quantized dimension must be less than tensor rank");

  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
  if (quantizedDimensionSize != ShapedType::kDynamic &&
      quantizedDimensionSize !=
          (int64_t)quantizedPerAxisType.getScales().size())
    return op->emitError(
        "quantized dimension size does not match number of scales");

  return success();
}

// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
//
// - quantizedType
//   Quantized type used in the input ('quant.dcast') or result
//   ('quant.qcast'), whether as a primitive type or in a tensor.
//
// - floatType
//   Float type used in the input ('quant.qcast') or result ('quant.dcast'),
//   whether as a primitive type or in a tensor.
//
// - containerType
//   Type of original input or result.
//
LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
                                   FloatType floatType, Type containerType) {
  if (quantizedType.getExpressedType() != floatType)
    return op->emitError(
        "expressed type in quantized type expected to match float type");

  // Verify integrity of per-axis quantization information, if present.
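  //
  // Illustrative example (hypothetical IR, not taken from the dialect tests):
  // for a container type such as
  //
  //   tensor<?x2x!quant.uniform<i8:f32:1, {2.0e+2:-19, 0.99872:127}>>
  //
  // the quantized dimension is 1, so its static size (2) must match the
  // number of scales in the per-axis type (also 2 here).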
  return verifyPerAxisQuantization(op, quantizedType, containerType);
}

} // namespace


//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//

void QuantDialect::initialize() {
  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
      >();
  detail::addBytecodeInterface(this);
}


//===----------------------------------------------------------------------===//
// DequantizeCastOp
//===----------------------------------------------------------------------===//

LogicalResult DequantizeCastOp::verify() {
  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
                              getInput().getType());
}

OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
  // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
  // with the value of x. Values x and y are guaranteed to be of the same type
  // in this pattern.
  auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
  if (!srcQcastOp)
    return {};
  assert(srcQcastOp.getInput().getType() == getType());
  return srcQcastOp.getInput();
}

FloatType DequantizeCastOp::getFloatType() {
  return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
}

QuantizedType DequantizeCastOp::getQuantizedType() {
  return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
}


//===----------------------------------------------------------------------===//
// QuantizeCastOp
//===----------------------------------------------------------------------===//

LogicalResult QuantizeCastOp::verify() {
  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
                              getInput().getType());
}

OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
  // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
  // with the value of x if the casts invert each other. Contrary to the
  // folding pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast ->
  // y), values x and y are not guaranteed to be of the same type here, as
  // they may use different quantization parameters.
  auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
  if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
    return {};
  return srcDcastOp.getInput();
}

FloatType QuantizeCastOp::getFloatType() {
  return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
}

QuantizedType QuantizeCastOp::getQuantizedType() {
  return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
}


//===----------------------------------------------------------------------===//
// StorageCastOp
//===----------------------------------------------------------------------===//

LogicalResult StorageCastOp::verify() {
  auto quantizedType = getQuantizedType();
  auto integerType = getIntegerType();
  if (quantizedType.getStorageType() != integerType)
    return emitError(
        "storage type in quantized type expected to match integer type");

  // Verify integrity of per-axis quantization information, if available. While
  // the quantization type may appear in the input or the result, their tensor
  // shapes are guaranteed to be identical at this point.
  return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
}

OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
  // Matches x -> quant.scast -> quant.scast -> y, replacing the second
  // quant.scast with the value of x if the casts invert each other.
  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
    return {};
  return srcScastOp.getInput();
}

IntegerType StorageCastOp::getIntegerType() {
  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
  if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
    return integerType;

  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
  return cast<IntegerType>(resultScalarType);
}

QuantizedType StorageCastOp::getQuantizedType() {
  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
  if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
    return quantizedType;

  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
  return cast<QuantizedType>(resultScalarType);
}


} // namespace quant
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
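
// Illustrative folding example (a hand-written sketch, not taken from the
// dialect tests):
//
//   %1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 1.0>
//   %2 = quant.dcast %1 : !quant.uniform<i8:f32, 1.0> to f32
//
// DequantizeCastOp::fold replaces all uses of %2 with %0, since both are
// guaranteed to have type f32. QuantizeCastOp::fold and StorageCastOp::fold
// simplify the inverse patterns in the same way, but only after checking that
// the inner and outer casts actually invert each other.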