1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TypeDetail.h" 10 #include "mlir/Dialect/Quant/IR/Quant.h" 11 #include "mlir/Dialect/Quant/IR/QuantTypes.h" 12 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/MLIRContext.h" 15 #include "llvm/ADT/StringRef.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/MathExtras.h" 18 19 using namespace mlir; 20 using namespace mlir::quant; 21 using namespace mlir::quant::detail; 22 23 namespace { 24 25 // Return the minimum scale representable in a given float type 26 double getMinScale(Type expressedType) { 27 auto floatType = cast<FloatType>(expressedType); 28 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble(); 29 } 30 31 // Return the maximum scale representable in a given float type 32 double getMaxScale(Type expressedType) { 33 auto floatType = cast<FloatType>(expressedType); 34 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble(); 35 } 36 37 } // namespace 38 39 unsigned QuantizedType::getFlags() const { 40 return static_cast<ImplType *>(impl)->flags; 41 } 42 43 bool QuantizedType::classof(Type type) { 44 return llvm::isa<QuantDialect>(type.getDialect()); 45 } 46 47 LogicalResult 48 QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, 49 unsigned flags, Type storageType, 50 Type expressedType, int64_t storageTypeMin, 51 int64_t storageTypeMax) { 52 // Verify that the storage type is integral. 53 // This restriction may be lifted at some point in favor of using bf16 54 // or f16 as exact representations on hardware where that is advantageous. 55 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType); 56 if (!intStorageType) 57 return emitError() << "storage type must be integral"; 58 unsigned integralWidth = intStorageType.getWidth(); 59 60 // Verify storage width. 61 if (integralWidth == 0 || integralWidth > MaxStorageBits) 62 return emitError() << "illegal storage type size: " << integralWidth; 63 64 // Verify storageTypeMin and storageTypeMax. 65 bool isSigned = 66 (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed; 67 int64_t defaultIntegerMin = 68 getDefaultMinimumForInteger(isSigned, integralWidth); 69 int64_t defaultIntegerMax = 70 getDefaultMaximumForInteger(isSigned, integralWidth); 71 if (storageTypeMax - storageTypeMin <= 0 || 72 storageTypeMin < defaultIntegerMin || 73 storageTypeMax > defaultIntegerMax) { 74 return emitError() << "illegal storage min and storage max: (" 75 << storageTypeMin << ":" << storageTypeMax << ")"; 76 } 77 return success(); 78 } 79 80 Type QuantizedType::getStorageType() const { 81 return static_cast<ImplType *>(impl)->storageType; 82 } 83 84 int64_t QuantizedType::getStorageTypeMin() const { 85 return static_cast<ImplType *>(impl)->storageTypeMin; 86 } 87 88 int64_t QuantizedType::getStorageTypeMax() const { 89 return static_cast<ImplType *>(impl)->storageTypeMax; 90 } 91 92 bool QuantizedType::hasStorageTypeBounds() const { 93 unsigned int integralWidth = getStorageTypeIntegralWidth(); 94 bool isSignedInteger = isSigned(); 95 int64_t defaultIntegerMin = 96 getDefaultMinimumForInteger(isSignedInteger, integralWidth); 97 int64_t defaultIntegerMax = 98 getDefaultMaximumForInteger(isSignedInteger, integralWidth); 99 return defaultIntegerMin != getStorageTypeMin() || 100 defaultIntegerMax != getStorageTypeMax(); 101 } 102 103 unsigned QuantizedType::getStorageTypeIntegralWidth() const { 104 // NOTE: If ever supporting non-integral storage types, some other scheme 105 // for determining the width will be needed. 106 return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth(); 107 } 108 109 Type QuantizedType::getExpressedType() const { 110 return static_cast<ImplType *>(impl)->expressedType; 111 } 112 113 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { 114 if (llvm::isa<ShapedType>(candidateExpressedType)) { 115 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() == 116 getExpressedType(); 117 } 118 return candidateExpressedType == getExpressedType(); 119 } 120 121 QuantizedType 122 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { 123 if (llvm::isa<ShapedType>(primitiveOrContainerType)) { 124 Type elementType = 125 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType(); 126 return llvm::dyn_cast<QuantizedType>(elementType); 127 } 128 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType); 129 } 130 131 Type QuantizedType::castFromStorageType(Type candidateType) { 132 if (candidateType == getStorageType()) { 133 // i.e. i32 -> quant<"uniform[i8:f32]{1.0}"> 134 return *this; 135 } 136 if (llvm::isa<RankedTensorType>(candidateType)) { 137 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> 138 return RankedTensorType::get( 139 llvm::cast<RankedTensorType>(candidateType).getShape(), 140 getStorageType()); 141 } 142 if (llvm::isa<UnrankedTensorType>(candidateType)) { 143 // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">> 144 return UnrankedTensorType::get(getStorageType()); 145 } 146 if (llvm::isa<VectorType>(candidateType)) { 147 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> 148 return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(), 149 getStorageType()); 150 } 151 152 return nullptr; 153 } 154 155 Type QuantizedType::castToStorageType(Type quantizedType) { 156 if (llvm::isa<QuantizedType>(quantizedType)) { 157 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 158 return llvm::cast<QuantizedType>(quantizedType).getStorageType(); 159 } 160 if (llvm::isa<ShapedType>(quantizedType)) { 161 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> 162 ShapedType sType = llvm::cast<ShapedType>(quantizedType); 163 if (!llvm::isa<QuantizedType>(sType.getElementType())) { 164 return nullptr; 165 } 166 Type storageType = 167 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType(); 168 if (llvm::isa<RankedTensorType>(quantizedType)) { 169 return RankedTensorType::get(sType.getShape(), storageType); 170 } 171 if (llvm::isa<UnrankedTensorType>(quantizedType)) { 172 return UnrankedTensorType::get(storageType); 173 } 174 if (llvm::isa<VectorType>(quantizedType)) { 175 return VectorType::get(sType.getShape(), storageType); 176 } 177 } 178 179 return nullptr; 180 } 181 182 Type QuantizedType::castFromExpressedType(Type candidateType) { 183 if (candidateType == getExpressedType()) { 184 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> 185 return *this; 186 } 187 if (llvm::isa<ShapedType>(candidateType)) { 188 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType); 189 if (candidateShapedType.getElementType() != getExpressedType()) { 190 return nullptr; 191 } 192 193 if (llvm::isa<RankedTensorType>(candidateType)) { 194 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> 195 return RankedTensorType::get(candidateShapedType.getShape(), *this); 196 } 197 if (llvm::isa<UnrankedTensorType>(candidateType)) { 198 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">> 199 return UnrankedTensorType::get(*this); 200 } 201 if (llvm::isa<VectorType>(candidateType)) { 202 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> 203 return VectorType::get(candidateShapedType.getShape(), *this); 204 } 205 } 206 207 return nullptr; 208 } 209 210 Type QuantizedType::castToExpressedType(Type quantizedType) { 211 if (llvm::isa<QuantizedType>(quantizedType)) { 212 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 213 return llvm::cast<QuantizedType>(quantizedType).getExpressedType(); 214 } 215 if (llvm::isa<ShapedType>(quantizedType)) { 216 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> 217 ShapedType sType = llvm::cast<ShapedType>(quantizedType); 218 if (!llvm::isa<QuantizedType>(sType.getElementType())) { 219 return nullptr; 220 } 221 Type expressedType = 222 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType(); 223 if (llvm::isa<RankedTensorType>(quantizedType)) { 224 return RankedTensorType::get(sType.getShape(), expressedType); 225 } 226 if (llvm::isa<UnrankedTensorType>(quantizedType)) { 227 return UnrankedTensorType::get(expressedType); 228 } 229 if (llvm::isa<VectorType>(quantizedType)) { 230 return VectorType::get(sType.getShape(), expressedType); 231 } 232 } 233 234 return nullptr; 235 } 236 237 Type QuantizedType::castExpressedToStorageType(Type candidateType) { 238 Type expressedQuantizedType = castFromExpressedType(candidateType); 239 if (!expressedQuantizedType) { 240 return nullptr; 241 } 242 return QuantizedType::castToStorageType(expressedQuantizedType); 243 } 244 245 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType, 246 Type expressedType, 247 int64_t storageTypeMin, 248 int64_t storageTypeMax) { 249 return Base::get(storageType.getContext(), flags, storageType, expressedType, 250 storageTypeMin, storageTypeMax); 251 } 252 253 AnyQuantizedType 254 AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError, 255 unsigned flags, Type storageType, 256 Type expressedType, int64_t storageTypeMin, 257 int64_t storageTypeMax) { 258 return Base::getChecked(emitError, storageType.getContext(), flags, 259 storageType, expressedType, storageTypeMin, 260 storageTypeMax); 261 } 262 263 LogicalResult 264 AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, 265 unsigned flags, Type storageType, 266 Type expressedType, int64_t storageTypeMin, 267 int64_t storageTypeMax) { 268 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType, 269 expressedType, storageTypeMin, 270 storageTypeMax))) { 271 return failure(); 272 } 273 274 // Verify that the expressed type is floating point. 275 // If this restriction is ever eliminated, the parser/printer must be 276 // extended. 277 if (expressedType && !llvm::isa<FloatType>(expressedType)) 278 return emitError() << "expressed type must be floating point"; 279 280 return success(); 281 } 282 283 UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType, 284 Type expressedType, double scale, 285 int64_t zeroPoint, 286 int64_t storageTypeMin, 287 int64_t storageTypeMax) { 288 return Base::get(storageType.getContext(), flags, storageType, expressedType, 289 scale, zeroPoint, storageTypeMin, storageTypeMax); 290 } 291 292 UniformQuantizedType UniformQuantizedType::getChecked( 293 function_ref<InFlightDiagnostic()> emitError, unsigned flags, 294 Type storageType, Type expressedType, double scale, int64_t zeroPoint, 295 int64_t storageTypeMin, int64_t storageTypeMax) { 296 return Base::getChecked(emitError, storageType.getContext(), flags, 297 storageType, expressedType, scale, zeroPoint, 298 storageTypeMin, storageTypeMax); 299 } 300 301 LogicalResult UniformQuantizedType::verifyInvariants( 302 function_ref<InFlightDiagnostic()> emitError, unsigned flags, 303 Type storageType, Type expressedType, double scale, int64_t zeroPoint, 304 int64_t storageTypeMin, int64_t storageTypeMax) { 305 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType, 306 expressedType, storageTypeMin, 307 storageTypeMax))) { 308 return failure(); 309 } 310 311 // Uniform quantization requires fully expressed parameters, including 312 // expressed type. 313 if (!expressedType) 314 return emitError() << "uniform quantization requires expressed type"; 315 316 // Verify that the expressed type is floating point. 317 // If this restriction is ever eliminated, the parser/printer must be 318 // extended. 319 if (!llvm::isa<FloatType>(expressedType)) 320 return emitError() << "expressed type must be floating point"; 321 322 // Verify scale. 323 double minScale = getMinScale(expressedType); 324 double maxScale = getMaxScale(expressedType); 325 if (scale < minScale || scale > maxScale) 326 return emitError() << "scale out of expressed type range [" << minScale 327 << ", " << maxScale << "]"; 328 329 return success(); 330 } 331 332 double UniformQuantizedType::getScale() const { return getImpl()->scale; } 333 334 int64_t UniformQuantizedType::getZeroPoint() const { 335 return getImpl()->zeroPoint; 336 } 337 338 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get( 339 unsigned flags, Type storageType, Type expressedType, 340 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, 341 int32_t quantizedDimension, int64_t storageTypeMin, 342 int64_t storageTypeMax) { 343 return Base::get(storageType.getContext(), flags, storageType, expressedType, 344 scales, zeroPoints, quantizedDimension, storageTypeMin, 345 storageTypeMax); 346 } 347 348 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked( 349 function_ref<InFlightDiagnostic()> emitError, unsigned flags, 350 Type storageType, Type expressedType, ArrayRef<double> scales, 351 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, 352 int64_t storageTypeMin, int64_t storageTypeMax) { 353 return Base::getChecked(emitError, storageType.getContext(), flags, 354 storageType, expressedType, scales, zeroPoints, 355 quantizedDimension, storageTypeMin, storageTypeMax); 356 } 357 358 LogicalResult UniformQuantizedPerAxisType::verifyInvariants( 359 function_ref<InFlightDiagnostic()> emitError, unsigned flags, 360 Type storageType, Type expressedType, ArrayRef<double> scales, 361 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, 362 int64_t storageTypeMin, int64_t storageTypeMax) { 363 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType, 364 expressedType, storageTypeMin, 365 storageTypeMax))) { 366 return failure(); 367 } 368 369 // Uniform quantization requires fully expressed parameters, including 370 // expressed type. 371 if (!expressedType) 372 return emitError() << "uniform quantization requires expressed type"; 373 374 // Verify that the expressed type is floating point. 375 // If this restriction is ever eliminated, the parser/printer must be 376 // extended. 377 if (!llvm::isa<FloatType>(expressedType)) 378 return emitError() << "expressed type must be floating point"; 379 380 // Ensure that the number of scales and zeroPoints match. 381 if (scales.size() != zeroPoints.size()) 382 return emitError() << "illegal number of scales and zeroPoints: " 383 << scales.size() << ", " << zeroPoints.size(); 384 385 // Verify scale. 386 double minScale = getMinScale(expressedType); 387 double maxScale = getMaxScale(expressedType); 388 for (double scale : scales) { 389 if (scale < minScale || scale > maxScale) 390 return emitError() << "scale out of expressed type range [" << minScale 391 << ", " << maxScale << "]"; 392 } 393 394 // Verify quantized dimension. 395 if (quantizedDimension < 0) 396 return emitError() << "illegal quantized dimension: " << quantizedDimension; 397 398 return success(); 399 } 400 401 ArrayRef<double> UniformQuantizedPerAxisType::getScales() const { 402 return getImpl()->getScales(); 403 } 404 405 ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const { 406 return getImpl()->getZeroPoints(); 407 } 408 409 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { 410 return getImpl()->quantizedDimension; 411 } 412 413 CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType, 414 double min, double max) { 415 return Base::get(expressedType.getContext(), expressedType, min, max); 416 } 417 418 CalibratedQuantizedType CalibratedQuantizedType::getChecked( 419 function_ref<InFlightDiagnostic()> emitError, Type expressedType, 420 double min, double max) { 421 return Base::getChecked(emitError, expressedType.getContext(), expressedType, 422 min, max); 423 } 424 425 LogicalResult CalibratedQuantizedType::verifyInvariants( 426 function_ref<InFlightDiagnostic()> emitError, Type expressedType, 427 double min, double max) { 428 // Verify that the expressed type is floating point. 429 // If this restriction is ever eliminated, the parser/printer must be 430 // extended. 431 if (!llvm::isa<FloatType>(expressedType)) 432 return emitError() << "expressed type must be floating point"; 433 if (max <= min) 434 return emitError() << "illegal min and max: (" << min << ":" << max << ")"; 435 436 return success(); 437 } 438 439 double CalibratedQuantizedType::getMin() const { return getImpl()->min; } 440 441 double CalibratedQuantizedType::getMax() const { return getImpl()->max; } 442