1363dd3f3SRob Suderman //===- UniformSupport.cpp - Support utilities for uniform quant -----------===// 2363dd3f3SRob Suderman // 3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6363dd3f3SRob Suderman // 7363dd3f3SRob Suderman //===----------------------------------------------------------------------===// 8363dd3f3SRob Suderman 9*852b6486SRafael Ubal #include "mlir/Dialect/Quant/Utils/UniformSupport.h" 1009f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 11363dd3f3SRob Suderman #include <numeric> 12363dd3f3SRob Suderman 13363dd3f3SRob Suderman using namespace mlir; 14363dd3f3SRob Suderman using namespace mlir::quant; 15363dd3f3SRob Suderman 16363dd3f3SRob Suderman static bool isQuantizablePrimitiveType(Type inputType) { 175550c821STres Popp return isa<FloatType>(inputType); 18363dd3f3SRob Suderman } 19363dd3f3SRob Suderman 20ad5d7aceSMehdi Amini ExpressedToQuantizedConverter 21363dd3f3SRob Suderman ExpressedToQuantizedConverter::forInputType(Type inputType) { 225550c821STres Popp if (isa<TensorType, VectorType>(inputType)) { 235550c821STres Popp Type elementType = cast<ShapedType>(inputType).getElementType(); 24c8c45985SRiver Riddle if (!isQuantizablePrimitiveType(elementType)) 25c8c45985SRiver Riddle return ExpressedToQuantizedConverter{inputType, nullptr}; 26c8c45985SRiver Riddle return ExpressedToQuantizedConverter{inputType, elementType}; 27c8c45985SRiver Riddle } 28c8c45985SRiver Riddle // Supported primitive type (which just is the expressed type). 29c8c45985SRiver Riddle if (isQuantizablePrimitiveType(inputType)) 30c8c45985SRiver Riddle return ExpressedToQuantizedConverter{inputType, inputType}; 31363dd3f3SRob Suderman // Unsupported. 32363dd3f3SRob Suderman return ExpressedToQuantizedConverter{inputType, nullptr}; 33363dd3f3SRob Suderman } 34363dd3f3SRob Suderman 35363dd3f3SRob Suderman Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const { 36363dd3f3SRob Suderman assert(expressedType && "convert() on unsupported conversion"); 375550c821STres Popp if (auto tensorType = dyn_cast<RankedTensorType>(inputType)) 38c8c45985SRiver Riddle return RankedTensorType::get(tensorType.getShape(), elementalType); 390a0aff2dSMikhail Goncharov if (dyn_cast<UnrankedTensorType>(inputType)) 40c8c45985SRiver Riddle return UnrankedTensorType::get(elementalType); 415550c821STres Popp if (auto vectorType = dyn_cast<VectorType>(inputType)) 42c8c45985SRiver Riddle return VectorType::get(vectorType.getShape(), elementalType); 43363dd3f3SRob Suderman 44942afe0cSFeng Liu // If the expressed types match, just use the new elemental type. 45c8c45985SRiver Riddle if (elementalType.getExpressedType() == expressedType) 46363dd3f3SRob Suderman return elementalType; 47363dd3f3SRob Suderman // Unsupported. 48363dd3f3SRob Suderman return nullptr; 49363dd3f3SRob Suderman } 50363dd3f3SRob Suderman 51363dd3f3SRob Suderman ElementsAttr 52363dd3f3SRob Suderman UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) { 535550c821STres Popp if (auto attr = dyn_cast<DenseFPElementsAttr>(realValue)) { 54363dd3f3SRob Suderman return convert(attr); 55363dd3f3SRob Suderman } 569db53a18SRiver Riddle // TODO: handles sparse elements attribute 57363dd3f3SRob Suderman return nullptr; 58363dd3f3SRob Suderman } 59363dd3f3SRob Suderman 60363dd3f3SRob Suderman DenseElementsAttr 61363dd3f3SRob Suderman UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) { 62363dd3f3SRob Suderman // Creates the converter for each chunk. Normally the size of the 63363dd3f3SRob Suderman // quantization dim is 3, so we can cache all the converters. 64363dd3f3SRob Suderman ShapedType type = attr.getType(); 65363dd3f3SRob Suderman size_t dimSize = type.getDimSize(quantizationDim); 66363dd3f3SRob Suderman if (dimSize != scales.size()) { 67363dd3f3SRob Suderman return {}; 68363dd3f3SRob Suderman } 69363dd3f3SRob Suderman SmallVector<UniformQuantizedValueConverter, 4> converters; 70363dd3f3SRob Suderman converters.reserve(dimSize); 71363dd3f3SRob Suderman for (int i = 0, e = dimSize; i != e; ++i) { 72363dd3f3SRob Suderman converters.push_back(getPerChunkConverter(i)); 73363dd3f3SRob Suderman } 74363dd3f3SRob Suderman 75363dd3f3SRob Suderman // Scan the elements of the dense elements attributes and quantize them by 76363dd3f3SRob Suderman // using the right quantization parameters. 77363dd3f3SRob Suderman int64_t flattenIndex = 0; 78363dd3f3SRob Suderman auto shape = type.getShape(); 79363dd3f3SRob Suderman int64_t chunkSize = 80363dd3f3SRob Suderman std::accumulate(std::next(shape.begin(), quantizationDim + 1), 81363dd3f3SRob Suderman shape.end(), 1, std::multiplies<int64_t>()); 821b97cdf8SRiver Riddle Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth); 83363dd3f3SRob Suderman return attr.mapValues(newElementType, [&](const APFloat &old) { 84363dd3f3SRob Suderman int chunkIndex = (flattenIndex++) / chunkSize; 85363dd3f3SRob Suderman return converters[chunkIndex % dimSize].quantizeFloatToInt(old); 86363dd3f3SRob Suderman }); 87363dd3f3SRob Suderman } 88