1 //===- UniformSupport.cpp - Support utilities for uniform quant -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/Utils/UniformSupport.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include <numeric> 12 13 using namespace mlir; 14 using namespace mlir::quant; 15 16 static bool isQuantizablePrimitiveType(Type inputType) { 17 return isa<FloatType>(inputType); 18 } 19 20 ExpressedToQuantizedConverter 21 ExpressedToQuantizedConverter::forInputType(Type inputType) { 22 if (isa<TensorType, VectorType>(inputType)) { 23 Type elementType = cast<ShapedType>(inputType).getElementType(); 24 if (!isQuantizablePrimitiveType(elementType)) 25 return ExpressedToQuantizedConverter{inputType, nullptr}; 26 return ExpressedToQuantizedConverter{inputType, elementType}; 27 } 28 // Supported primitive type (which just is the expressed type). 29 if (isQuantizablePrimitiveType(inputType)) 30 return ExpressedToQuantizedConverter{inputType, inputType}; 31 // Unsupported. 32 return ExpressedToQuantizedConverter{inputType, nullptr}; 33 } 34 35 Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const { 36 assert(expressedType && "convert() on unsupported conversion"); 37 if (auto tensorType = dyn_cast<RankedTensorType>(inputType)) 38 return RankedTensorType::get(tensorType.getShape(), elementalType); 39 if (dyn_cast<UnrankedTensorType>(inputType)) 40 return UnrankedTensorType::get(elementalType); 41 if (auto vectorType = dyn_cast<VectorType>(inputType)) 42 return VectorType::get(vectorType.getShape(), elementalType); 43 44 // If the expressed types match, just use the new elemental type. 45 if (elementalType.getExpressedType() == expressedType) 46 return elementalType; 47 // Unsupported. 48 return nullptr; 49 } 50 51 ElementsAttr 52 UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) { 53 if (auto attr = dyn_cast<DenseFPElementsAttr>(realValue)) { 54 return convert(attr); 55 } 56 // TODO: handles sparse elements attribute 57 return nullptr; 58 } 59 60 DenseElementsAttr 61 UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) { 62 // Creates the converter for each chunk. Normally the size of the 63 // quantization dim is 3, so we can cache all the converters. 64 ShapedType type = attr.getType(); 65 size_t dimSize = type.getDimSize(quantizationDim); 66 if (dimSize != scales.size()) { 67 return {}; 68 } 69 SmallVector<UniformQuantizedValueConverter, 4> converters; 70 converters.reserve(dimSize); 71 for (int i = 0, e = dimSize; i != e; ++i) { 72 converters.push_back(getPerChunkConverter(i)); 73 } 74 75 // Scan the elements of the dense elements attributes and quantize them by 76 // using the right quantization parameters. 77 int64_t flattenIndex = 0; 78 auto shape = type.getShape(); 79 int64_t chunkSize = 80 std::accumulate(std::next(shape.begin(), quantizationDim + 1), 81 shape.end(), 1, std::multiplies<int64_t>()); 82 Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth); 83 return attr.mapValues(newElementType, [&](const APFloat &old) { 84 int chunkIndex = (flattenIndex++) / chunkSize; 85 return converters[chunkIndex % dimSize].quantizeFloatToInt(old); 86 }); 87 } 88