1 //===- TypeParser.h - Quantization Type Parser ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/IR/Quant.h" 10 #include "mlir/Dialect/Quant/IR/QuantTypes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Location.h" 14 #include "mlir/IR/Types.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/Support/Format.h" 17 #include "llvm/Support/MathExtras.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Support/raw_ostream.h" 20 21 using namespace mlir; 22 using namespace quant; 23 24 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { 25 auto typeLoc = parser.getCurrentLocation(); 26 IntegerType type; 27 28 // Parse storage type (alpha_ident, integer_literal). 29 StringRef identifier; 30 unsigned storageTypeWidth = 0; 31 OptionalParseResult result = parser.parseOptionalType(type); 32 if (result.has_value()) { 33 if (!succeeded(*result)) 34 return nullptr; 35 isSigned = !type.isUnsigned(); 36 storageTypeWidth = type.getWidth(); 37 } else if (succeeded(parser.parseKeyword(&identifier))) { 38 // Otherwise, this must be an unsigned integer (`u` integer-literal). 39 if (!identifier.consume_front("u")) { 40 parser.emitError(typeLoc, "illegal storage type prefix"); 41 return nullptr; 42 } 43 if (identifier.getAsInteger(10, storageTypeWidth)) { 44 parser.emitError(typeLoc, "expected storage type width"); 45 return nullptr; 46 } 47 isSigned = false; 48 type = parser.getBuilder().getIntegerType(storageTypeWidth); 49 } else { 50 return nullptr; 51 } 52 53 if (storageTypeWidth == 0 || 54 storageTypeWidth > QuantizedType::MaxStorageBits) { 55 parser.emitError(typeLoc, "illegal storage type size: ") 56 << storageTypeWidth; 57 return nullptr; 58 } 59 60 return type; 61 } 62 63 static ParseResult parseStorageRange(DialectAsmParser &parser, 64 IntegerType storageType, bool isSigned, 65 int64_t &storageTypeMin, 66 int64_t &storageTypeMax) { 67 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger( 68 isSigned, storageType.getWidth()); 69 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger( 70 isSigned, storageType.getWidth()); 71 if (failed(parser.parseOptionalLess())) { 72 storageTypeMin = defaultIntegerMin; 73 storageTypeMax = defaultIntegerMax; 74 return success(); 75 } 76 77 // Explicit storage min and storage max. 78 SMLoc minLoc = parser.getCurrentLocation(), maxLoc; 79 if (parser.parseInteger(storageTypeMin) || parser.parseColon() || 80 parser.getCurrentLocation(&maxLoc) || 81 parser.parseInteger(storageTypeMax) || parser.parseGreater()) 82 return failure(); 83 if (storageTypeMin < defaultIntegerMin) { 84 return parser.emitError(minLoc, "illegal storage type minimum: ") 85 << storageTypeMin; 86 } 87 if (storageTypeMax > defaultIntegerMax) { 88 return parser.emitError(maxLoc, "illegal storage type maximum: ") 89 << storageTypeMax; 90 } 91 return success(); 92 } 93 94 static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, 95 double &min, double &max) { 96 auto typeLoc = parser.getCurrentLocation(); 97 FloatType type; 98 99 if (failed(parser.parseType(type))) { 100 parser.emitError(typeLoc, "expecting float expressed type"); 101 return nullptr; 102 } 103 104 // Calibrated min and max values. 105 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() || 106 parser.parseFloat(max) || parser.parseGreater()) { 107 parser.emitError(typeLoc, "calibrated values must be present"); 108 return nullptr; 109 } 110 return type; 111 } 112 113 /// Parses an AnyQuantizedType. 114 /// 115 /// any ::= `any<` storage-spec (expressed-type-spec)?`>` 116 /// storage-spec ::= storage-type (`<` storage-range `>`)? 117 /// storage-range ::= integer-literal `:` integer-literal 118 /// storage-type ::= (`i` | `u`) integer-literal 119 /// expressed-type-spec ::= `:` `f` integer-literal 120 static Type parseAnyType(DialectAsmParser &parser) { 121 IntegerType storageType; 122 FloatType expressedType; 123 unsigned typeFlags = 0; 124 int64_t storageTypeMin; 125 int64_t storageTypeMax; 126 127 // Type specification. 128 if (parser.parseLess()) 129 return nullptr; 130 131 // Storage type. 132 bool isSigned = false; 133 storageType = parseStorageType(parser, isSigned); 134 if (!storageType) { 135 return nullptr; 136 } 137 if (isSigned) { 138 typeFlags |= QuantizationFlags::Signed; 139 } 140 141 // Storage type range. 142 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, 143 storageTypeMax)) { 144 return nullptr; 145 } 146 147 // Optional expressed type. 148 if (succeeded(parser.parseOptionalColon())) { 149 if (parser.parseType(expressedType)) { 150 return nullptr; 151 } 152 } 153 154 if (parser.parseGreater()) { 155 return nullptr; 156 } 157 158 return parser.getChecked<AnyQuantizedType>( 159 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); 160 } 161 162 static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, 163 int64_t &zeroPoint) { 164 // scale[:zeroPoint]? 165 // scale. 166 if (parser.parseFloat(scale)) 167 return failure(); 168 169 // zero point. 170 zeroPoint = 0; 171 if (failed(parser.parseOptionalColon())) { 172 // Default zero point. 173 return success(); 174 } 175 176 return parser.parseInteger(zeroPoint); 177 } 178 179 /// Parses a UniformQuantizedType. 180 /// 181 /// uniform_type ::= uniform_per_layer 182 /// | uniform_per_axis 183 /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec 184 /// `,` scale-zero `>` 185 /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec 186 /// axis-spec `,` scale-zero-list `>` 187 /// storage-spec ::= storage-type (`<` storage-range `>`)? 188 /// storage-range ::= integer-literal `:` integer-literal 189 /// storage-type ::= (`i` | `u`) integer-literal 190 /// expressed-type-spec ::= `:` `f` integer-literal 191 /// axis-spec ::= `:` integer-literal 192 /// scale-zero ::= float-literal `:` integer-literal 193 /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}` 194 static Type parseUniformType(DialectAsmParser &parser) { 195 IntegerType storageType; 196 FloatType expressedType; 197 unsigned typeFlags = 0; 198 int64_t storageTypeMin; 199 int64_t storageTypeMax; 200 bool isPerAxis = false; 201 int32_t quantizedDimension; 202 SmallVector<double, 1> scales; 203 SmallVector<int64_t, 1> zeroPoints; 204 205 // Type specification. 206 if (parser.parseLess()) { 207 return nullptr; 208 } 209 210 // Storage type. 211 bool isSigned = false; 212 storageType = parseStorageType(parser, isSigned); 213 if (!storageType) { 214 return nullptr; 215 } 216 if (isSigned) { 217 typeFlags |= QuantizationFlags::Signed; 218 } 219 220 // Storage type range. 221 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, 222 storageTypeMax)) { 223 return nullptr; 224 } 225 226 // Expressed type. 227 if (parser.parseColon() || parser.parseType(expressedType)) { 228 return nullptr; 229 } 230 231 // Optionally parse quantized dimension for per-axis quantization. 232 if (succeeded(parser.parseOptionalColon())) { 233 if (parser.parseInteger(quantizedDimension)) 234 return nullptr; 235 isPerAxis = true; 236 } 237 238 // Comma leading into range_spec. 239 if (parser.parseComma()) { 240 return nullptr; 241 } 242 243 // Parameter specification. 244 // For per-axis, ranges are in a {} delimitted list. 245 if (isPerAxis) { 246 if (parser.parseLBrace()) { 247 return nullptr; 248 } 249 } 250 251 // Parse scales/zeroPoints. 252 SMLoc scaleZPLoc = parser.getCurrentLocation(); 253 do { 254 scales.resize(scales.size() + 1); 255 zeroPoints.resize(zeroPoints.size() + 1); 256 if (parseQuantParams(parser, scales.back(), zeroPoints.back())) { 257 return nullptr; 258 } 259 } while (isPerAxis && succeeded(parser.parseOptionalComma())); 260 261 if (isPerAxis) { 262 if (parser.parseRBrace()) { 263 return nullptr; 264 } 265 } 266 267 if (parser.parseGreater()) { 268 return nullptr; 269 } 270 271 if (!isPerAxis && scales.size() > 1) { 272 return (parser.emitError(scaleZPLoc, 273 "multiple scales/zeroPoints provided, but " 274 "quantizedDimension wasn't specified"), 275 nullptr); 276 } 277 278 if (isPerAxis) { 279 ArrayRef<double> scalesRef(scales.begin(), scales.end()); 280 ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); 281 return parser.getChecked<UniformQuantizedPerAxisType>( 282 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, 283 quantizedDimension, storageTypeMin, storageTypeMax); 284 } 285 286 return parser.getChecked<UniformQuantizedType>( 287 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(), 288 storageTypeMin, storageTypeMax); 289 } 290 291 /// Parses an CalibratedQuantizedType. 292 /// 293 /// calibrated ::= `calibrated<` expressed-spec `>` 294 /// expressed-spec ::= expressed-type `<` calibrated-range `>` 295 /// expressed-type ::= `f` integer-literal 296 /// calibrated-range ::= float-literal `:` float-literal 297 static Type parseCalibratedType(DialectAsmParser &parser) { 298 FloatType expressedType; 299 double min; 300 double max; 301 302 // Type specification. 303 if (parser.parseLess()) 304 return nullptr; 305 306 // Expressed type. 307 expressedType = parseExpressedTypeAndRange(parser, min, max); 308 if (!expressedType) { 309 return nullptr; 310 } 311 312 if (parser.parseGreater()) { 313 return nullptr; 314 } 315 316 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max); 317 } 318 319 /// Parse a type registered to this dialect. 320 Type QuantDialect::parseType(DialectAsmParser &parser) const { 321 // All types start with an identifier that we switch on. 322 StringRef typeNameSpelling; 323 if (failed(parser.parseKeyword(&typeNameSpelling))) 324 return nullptr; 325 326 if (typeNameSpelling == "uniform") 327 return parseUniformType(parser); 328 if (typeNameSpelling == "any") 329 return parseAnyType(parser); 330 if (typeNameSpelling == "calibrated") 331 return parseCalibratedType(parser); 332 333 parser.emitError(parser.getNameLoc(), 334 "unknown quantized type " + typeNameSpelling); 335 return nullptr; 336 } 337 338 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { 339 // storage type 340 unsigned storageWidth = type.getStorageTypeIntegralWidth(); 341 bool isSigned = type.isSigned(); 342 if (isSigned) { 343 out << "i" << storageWidth; 344 } else { 345 out << "u" << storageWidth; 346 } 347 348 // storageTypeMin and storageTypeMax if not default. 349 if (type.hasStorageTypeBounds()) { 350 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() 351 << ">"; 352 } 353 } 354 355 static void printQuantParams(double scale, int64_t zeroPoint, 356 DialectAsmPrinter &out) { 357 out << scale; 358 if (zeroPoint != 0) { 359 out << ":" << zeroPoint; 360 } 361 } 362 363 /// Helper that prints a AnyQuantizedType. 364 static void printAnyQuantizedType(AnyQuantizedType type, 365 DialectAsmPrinter &out) { 366 out << "any<"; 367 printStorageType(type, out); 368 if (Type expressedType = type.getExpressedType()) { 369 out << ":" << expressedType; 370 } 371 out << ">"; 372 } 373 374 /// Helper that prints a UniformQuantizedType. 375 static void printUniformQuantizedType(UniformQuantizedType type, 376 DialectAsmPrinter &out) { 377 out << "uniform<"; 378 printStorageType(type, out); 379 out << ":" << type.getExpressedType() << ", "; 380 381 // scheme specific parameters 382 printQuantParams(type.getScale(), type.getZeroPoint(), out); 383 out << ">"; 384 } 385 386 /// Helper that prints a UniformQuantizedPerAxisType. 387 static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, 388 DialectAsmPrinter &out) { 389 out << "uniform<"; 390 printStorageType(type, out); 391 out << ":" << type.getExpressedType() << ":"; 392 out << type.getQuantizedDimension(); 393 out << ", "; 394 395 // scheme specific parameters 396 ArrayRef<double> scales = type.getScales(); 397 ArrayRef<int64_t> zeroPoints = type.getZeroPoints(); 398 out << "{"; 399 llvm::interleave( 400 llvm::seq<size_t>(0, scales.size()), out, 401 [&](size_t index) { 402 printQuantParams(scales[index], zeroPoints[index], out); 403 }, 404 ","); 405 out << "}>"; 406 } 407 408 /// Helper that prints a CalibratedQuantizedType. 409 static void printCalibratedQuantizedType(CalibratedQuantizedType type, 410 DialectAsmPrinter &out) { 411 out << "calibrated<" << type.getExpressedType(); 412 out << "<" << type.getMin() << ":" << type.getMax() << ">"; 413 out << ">"; 414 } 415 416 /// Print a type registered to this dialect. 417 void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { 418 if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type)) 419 printAnyQuantizedType(anyType, os); 420 else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type)) 421 printUniformQuantizedType(uniformType, os); 422 else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type)) 423 printUniformQuantizedPerAxisType(perAxisType, os); 424 else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type)) 425 printCalibratedQuantizedType(calibratedType, os); 426 else 427 llvm_unreachable("Unhandled quantized type"); 428 } 429