1*852b6486SRafael Ubal //===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===// 2*852b6486SRafael Ubal // 3*852b6486SRafael Ubal // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*852b6486SRafael Ubal // See https://llvm.org/LICENSE.txt for license information. 5*852b6486SRafael Ubal // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*852b6486SRafael Ubal // 7*852b6486SRafael Ubal //===----------------------------------------------------------------------===// 8*852b6486SRafael Ubal // 9*852b6486SRafael Ubal // Transforms `quant.dcast` and `quant.qcast` into lower-level ops. 10*852b6486SRafael Ubal // 11*852b6486SRafael Ubal //===----------------------------------------------------------------------===// 12*852b6486SRafael Ubal 13*852b6486SRafael Ubal #include "mlir/Dialect/Arith/IR/Arith.h" 14*852b6486SRafael Ubal #include "mlir/Dialect/Func/IR/FuncOps.h" 15*852b6486SRafael Ubal #include "mlir/Dialect/Linalg/IR/Linalg.h" 16*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/Quant.h" 17*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/QuantTypes.h" 18*852b6486SRafael Ubal #include "mlir/Dialect/Quant/Transforms/Passes.h" 19*852b6486SRafael Ubal #include "mlir/Dialect/Shape/IR/Shape.h" 20*852b6486SRafael Ubal #include "mlir/Dialect/Tensor/IR/Tensor.h" 21*852b6486SRafael Ubal #include "mlir/IR/Matchers.h" 22*852b6486SRafael Ubal #include "mlir/IR/PatternMatch.h" 23*852b6486SRafael Ubal #include "mlir/Transforms/DialectConversion.h" 24*852b6486SRafael Ubal 25*852b6486SRafael Ubal namespace mlir { 26*852b6486SRafael Ubal namespace quant { 27*852b6486SRafael Ubal 28*852b6486SRafael Ubal #define GEN_PASS_DEF_LOWERQUANTOPS 29*852b6486SRafael Ubal #include "mlir/Dialect/Quant/Transforms/Passes.h.inc" 30*852b6486SRafael Ubal 31*852b6486SRafael Ubal namespace { 32*852b6486SRafael Ubal 33*852b6486SRafael Ubal // If 'inputType' is a tensor, return its element type. If it is a scalar, 34*852b6486SRafael Ubal // return it as is. 35*852b6486SRafael Ubal Type getScalarType(Type inputType) { 36*852b6486SRafael Ubal if (auto tensorType = dyn_cast<TensorType>(inputType)) 37*852b6486SRafael Ubal return tensorType.getElementType(); 38*852b6486SRafael Ubal return inputType; 39*852b6486SRafael Ubal } 40*852b6486SRafael Ubal 41*852b6486SRafael Ubal // Return the shape of an input value as a list of attributes (static dimensions) 42*852b6486SRafael Ubal // and values (dynamic dimensions). If 'input' is a scalar, an empty list is 43*852b6486SRafael Ubal // returned. If 'input' is a tensor, its shape is returned. 44*852b6486SRafael Ubal SmallVector<OpFoldResult> 45*852b6486SRafael Ubal getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) { 46*852b6486SRafael Ubal if (isa<TensorType>(input.getType())) 47*852b6486SRafael Ubal return tensor::getMixedSizes(builder, loc, input); 48*852b6486SRafael Ubal return {}; 49*852b6486SRafael Ubal } 50*852b6486SRafael Ubal 51*852b6486SRafael Ubal // If 'referenceType' is a scalar, return 'elementType' as is. If 52*852b6486SRafael Ubal // 'referenceType' is a tensor, return another tensor with the same shape and 53*852b6486SRafael Ubal // elements of type 'elementType'. 54*852b6486SRafael Ubal Type getScalarOrTensorType(Type elementType, Type referenceType) { 55*852b6486SRafael Ubal if (auto tensorType = dyn_cast<TensorType>(referenceType)) 56*852b6486SRafael Ubal return tensorType.clone(elementType); 57*852b6486SRafael Ubal return elementType; 58*852b6486SRafael Ubal } 59*852b6486SRafael Ubal 60*852b6486SRafael Ubal // Return a constant with the given value. If 'referenceType' is a tensor, a 61*852b6486SRafael Ubal // tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a 62*852b6486SRafael Ubal // scalar, 'referenceShape' is ignored and a scalar constant is returned. 63*852b6486SRafael Ubal Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar, 64*852b6486SRafael Ubal Type referenceType, 65*852b6486SRafael Ubal ArrayRef<OpFoldResult> referenceShape) { 66*852b6486SRafael Ubal // If the result type is a scalar, return the unmodified scalar constant. 67*852b6486SRafael Ubal auto tensorType = dyn_cast<TensorType>(referenceType); 68*852b6486SRafael Ubal if (!tensorType) { 69*852b6486SRafael Ubal assert(referenceShape.empty()); 70*852b6486SRafael Ubal return scalar; 71*852b6486SRafael Ubal } 72*852b6486SRafael Ubal 73*852b6486SRafael Ubal // Create tensor splat 74*852b6486SRafael Ubal auto tensorConstant = 75*852b6486SRafael Ubal builder.create<tensor::SplatOp>(loc, scalar, referenceShape); 76*852b6486SRafael Ubal return tensorConstant; 77*852b6486SRafael Ubal } 78*852b6486SRafael Ubal 79*852b6486SRafael Ubal // Reshape an unranked tensor into a 1D ranked tensor. 80*852b6486SRafael Ubal // 81*852b6486SRafael Ubal // - input 82*852b6486SRafael Ubal // Unranked tensor. 83*852b6486SRafael Ubal // 84*852b6486SRafael Ubal // Return values: 85*852b6486SRafael Ubal // 86*852b6486SRafael Ubal // - flatInput 87*852b6486SRafael Ubal // 1D ranked, dynamically shaped tensor. 88*852b6486SRafael Ubal // 89*852b6486SRafael Ubal // - inputShape 90*852b6486SRafael Ubal // 1D extent tensor containing the shape of the original unranked input. 91*852b6486SRafael Ubal // 92*852b6486SRafael Ubal std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc, 93*852b6486SRafael Ubal Value input) { 94*852b6486SRafael Ubal // Get unranked input shape and total size 95*852b6486SRafael Ubal auto *context = builder.getContext(); 96*852b6486SRafael Ubal auto shapeType = shape::getExtentTensorType(context); 97*852b6486SRafael Ubal auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input); 98*852b6486SRafael Ubal Value inputSize = builder.create<shape::NumElementsOp>( 99*852b6486SRafael Ubal loc, builder.getIndexType(), inputShape); 100*852b6486SRafael Ubal 101*852b6486SRafael Ubal // Turn input size into 1D tensor 102*852b6486SRafael Ubal auto flatShapeType = shape::getExtentTensorType(context, 1); 103*852b6486SRafael Ubal auto flatInputShape = builder.create<tensor::FromElementsOp>( 104*852b6486SRafael Ubal loc, flatShapeType, inputSize); 105*852b6486SRafael Ubal 106*852b6486SRafael Ubal // Reshape input tensor into 1D 107*852b6486SRafael Ubal auto inputType = cast<UnrankedTensorType>(input.getType()); 108*852b6486SRafael Ubal auto elementType = inputType.getElementType(); 109*852b6486SRafael Ubal auto flatInputType = 110*852b6486SRafael Ubal RankedTensorType::get({ShapedType::kDynamic}, elementType); 111*852b6486SRafael Ubal auto flatInput = builder.create<tensor::ReshapeOp>( 112*852b6486SRafael Ubal loc, flatInputType, input, flatInputShape); 113*852b6486SRafael Ubal return std::make_pair(flatInput, inputShape); 114*852b6486SRafael Ubal } 115*852b6486SRafael Ubal 116*852b6486SRafael Ubal // Reshape an unranked tensor into a 3D ranked tensor where the central 117*852b6486SRafael Ubal // dimension of the result tensor corresponds to dimension 'axis' of the input 118*852b6486SRafael Ubal // tensor. 119*852b6486SRafael Ubal // 120*852b6486SRafael Ubal // - input 121*852b6486SRafael Ubal // Unranked tensor. 122*852b6486SRafael Ubal // 123*852b6486SRafael Ubal // - axis 124*852b6486SRafael Ubal // Index of the input dimension around which other input dimiensions will be 125*852b6486SRafael Ubal // collapsed. 126*852b6486SRafael Ubal // 127*852b6486SRafael Ubal // - axisSize 128*852b6486SRafael Ubal // Size of input dimension 'axis'. 129*852b6486SRafael Ubal // 130*852b6486SRafael Ubal // Return values: 131*852b6486SRafael Ubal // 132*852b6486SRafael Ubal // - flatInput 133*852b6486SRafael Ubal // 3D ranked tensor of shape [?, axisSize, ?]. 134*852b6486SRafael Ubal // 135*852b6486SRafael Ubal // - inputShape 136*852b6486SRafael Ubal // 1D extent tensor containing the shape of the original unranked input. 137*852b6486SRafael Ubal // 138*852b6486SRafael Ubal std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder, 139*852b6486SRafael Ubal Location loc, 140*852b6486SRafael Ubal Value input, 141*852b6486SRafael Ubal int64_t axis, 142*852b6486SRafael Ubal int64_t axisSize) { 143*852b6486SRafael Ubal // Get full tensor shape 144*852b6486SRafael Ubal auto *context = builder.getContext(); 145*852b6486SRafael Ubal auto indexType = builder.getIndexType(); 146*852b6486SRafael Ubal auto shapeType = shape::getExtentTensorType(context); 147*852b6486SRafael Ubal auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input); 148*852b6486SRafael Ubal 149*852b6486SRafael Ubal // Get shape and sizes on left and right of axis 150*852b6486SRafael Ubal auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis); 151*852b6486SRafael Ubal auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1); 152*852b6486SRafael Ubal auto shapeLeft = builder.create<shape::SplitAtOp>( 153*852b6486SRafael Ubal loc, TypeRange{shapeType, shapeType}, inputShape, axisValue) 154*852b6486SRafael Ubal .getResult(0); 155*852b6486SRafael Ubal auto sizeLeft = builder.create<shape::NumElementsOp>( 156*852b6486SRafael Ubal loc, indexType, shapeLeft); 157*852b6486SRafael Ubal auto shapeRight = builder.create<shape::SplitAtOp>( 158*852b6486SRafael Ubal loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue) 159*852b6486SRafael Ubal .getResult(1); 160*852b6486SRafael Ubal auto sizeRight = builder.create<shape::NumElementsOp>( 161*852b6486SRafael Ubal loc, indexType, shapeRight); 162*852b6486SRafael Ubal 163*852b6486SRafael Ubal // Compute flat input shape as a 3-element 1D tensor 164*852b6486SRafael Ubal auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize); 165*852b6486SRafael Ubal auto flatShapeType = shape::getExtentTensorType(context, 3); 166*852b6486SRafael Ubal auto flatInputShape = builder.create<tensor::FromElementsOp>( 167*852b6486SRafael Ubal loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight}); 168*852b6486SRafael Ubal 169*852b6486SRafael Ubal // Reshape input to 3D tensor 170*852b6486SRafael Ubal auto inputType = cast<UnrankedTensorType>(input.getType()); 171*852b6486SRafael Ubal auto elementType = inputType.getElementType(); 172*852b6486SRafael Ubal auto flatInputType = RankedTensorType::get( 173*852b6486SRafael Ubal {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType); 174*852b6486SRafael Ubal auto flatInput = builder.create<tensor::ReshapeOp>( 175*852b6486SRafael Ubal loc, flatInputType, input, flatInputShape); 176*852b6486SRafael Ubal 177*852b6486SRafael Ubal return std::make_pair(flatInput, inputShape); 178*852b6486SRafael Ubal } 179*852b6486SRafael Ubal 180*852b6486SRafael Ubal // Reshape an input tensor into its original unranked shape. 181*852b6486SRafael Ubal // 182*852b6486SRafael Ubal // - input 183*852b6486SRafael Ubal // Ranked tensor. 184*852b6486SRafael Ubal // 185*852b6486SRafael Ubal // - inputShape 186*852b6486SRafael Ubal // 1D extent tensor. 187*852b6486SRafael Ubal // 188*852b6486SRafael Ubal Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, 189*852b6486SRafael Ubal Value inputShape) { 190*852b6486SRafael Ubal auto inputType = cast<RankedTensorType>(input.getType()); 191*852b6486SRafael Ubal auto elementType = inputType.getElementType(); 192*852b6486SRafael Ubal auto unrankedType = UnrankedTensorType::get(elementType); 193*852b6486SRafael Ubal return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape); 194*852b6486SRafael Ubal } 195*852b6486SRafael Ubal 196*852b6486SRafael Ubal // Create a tensor constant containing all scales in a per-channel quantized 197*852b6486SRafael Ubal // type. Example: 198*852b6486SRafael Ubal // 199*852b6486SRafael Ubal // !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}> 200*852b6486SRafael Ubal // 201*852b6486SRafael Ubal // produces 202*852b6486SRafael Ubal // 203*852b6486SRafael Ubal // %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32> 204*852b6486SRafael Ubal // 205*852b6486SRafael Ubal Value materializePerChannelScales(OpBuilder &builder, Location loc, 206*852b6486SRafael Ubal UniformQuantizedPerAxisType quantizedType) { 207*852b6486SRafael Ubal auto scales = quantizedType.getScales(); 208*852b6486SRafael Ubal auto expressedType = quantizedType.getExpressedType(); 209*852b6486SRafael Ubal auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute { 210*852b6486SRafael Ubal return builder.getFloatAttr(expressedType, scale); 211*852b6486SRafael Ubal }); 212*852b6486SRafael Ubal auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType); 213*852b6486SRafael Ubal auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); 214*852b6486SRafael Ubal return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr); 215*852b6486SRafael Ubal } 216*852b6486SRafael Ubal 217*852b6486SRafael Ubal // Create a tensor constant containing all zero points in a per-channel 218*852b6486SRafael Ubal // quantized type. Example: 219*852b6486SRafael Ubal // 220*852b6486SRafael Ubal // !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}> 221*852b6486SRafael Ubal // 222*852b6486SRafael Ubal // produces 223*852b6486SRafael Ubal // 224*852b6486SRafael Ubal // %cst = arith.constant dense<[10, 20]> : tensor<2xi8> 225*852b6486SRafael Ubal // 226*852b6486SRafael Ubal Value materializePerChannelZeroPoints( 227*852b6486SRafael Ubal OpBuilder &builder, Location loc, 228*852b6486SRafael Ubal UniformQuantizedPerAxisType quantizedType) { 229*852b6486SRafael Ubal auto zeroPoints = quantizedType.getZeroPoints(); 230*852b6486SRafael Ubal auto storageType = quantizedType.getStorageType(); 231*852b6486SRafael Ubal auto zeroPointAttrs = llvm::map_to_vector( 232*852b6486SRafael Ubal zeroPoints, 233*852b6486SRafael Ubal [&](int64_t zeroPoint) -> Attribute { 234*852b6486SRafael Ubal return builder.getIntegerAttr(storageType, zeroPoint); 235*852b6486SRafael Ubal }); 236*852b6486SRafael Ubal auto tensorType = 237*852b6486SRafael Ubal RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType); 238*852b6486SRafael Ubal auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); 239*852b6486SRafael Ubal return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr); 240*852b6486SRafael Ubal } 241*852b6486SRafael Ubal 242*852b6486SRafael Ubal // Clamp the given scalar or tensor input using the storage bounds encoded in 243*852b6486SRafael Ubal // the given quantized type, if present. 244*852b6486SRafael Ubal // 245*852b6486SRafael Ubal // - input 246*852b6486SRafael Ubal // Scalar or ranked tensor input. The element type must match the storage type 247*852b6486SRafael Ubal // of 'quantizedType'. 248*852b6486SRafael Ubal // 249*852b6486SRafael Ubal // - inputShape 250*852b6486SRafael Ubal // If 'input' is a tensor, combination of attributes/values representing its 251*852b6486SRafael Ubal // static/dynamic dimensions. If 'input' is a scalar, empty list. 252*852b6486SRafael Ubal // 253*852b6486SRafael Ubal // - quantizedType 254*852b6486SRafael Ubal // Per-axis or per-channel quantized type. 255*852b6486SRafael Ubal Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, 256*852b6486SRafael Ubal ArrayRef<OpFoldResult> inputShape, 257*852b6486SRafael Ubal QuantizedType quantizedType) { 258*852b6486SRafael Ubal // If quantized type does not narrow down the storage type range, there is 259*852b6486SRafael Ubal // nothing to do. 260*852b6486SRafael Ubal if (!quantizedType.hasStorageTypeBounds()) 261*852b6486SRafael Ubal return input; 262*852b6486SRafael Ubal 263*852b6486SRafael Ubal // Materialize bounds 264*852b6486SRafael Ubal auto inputType = input.getType(); 265*852b6486SRafael Ubal auto storageType = quantizedType.getStorageType(); 266*852b6486SRafael Ubal auto storageMinScalar = builder.create<arith::ConstantIntOp>( 267*852b6486SRafael Ubal loc, quantizedType.getStorageTypeMin(), storageType); 268*852b6486SRafael Ubal auto storageMaxScalar = builder.create<arith::ConstantIntOp>( 269*852b6486SRafael Ubal loc, quantizedType.getStorageTypeMax(), storageType); 270*852b6486SRafael Ubal auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar, 271*852b6486SRafael Ubal inputType, inputShape); 272*852b6486SRafael Ubal auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar, 273*852b6486SRafael Ubal inputType, inputShape); 274*852b6486SRafael Ubal 275*852b6486SRafael Ubal // Clamp 276*852b6486SRafael Ubal if (quantizedType.isSigned()) { 277*852b6486SRafael Ubal input = builder.create<arith::MaxSIOp>(loc, input, storageMin); 278*852b6486SRafael Ubal input = builder.create<arith::MinSIOp>(loc, input, storageMax); 279*852b6486SRafael Ubal } else { 280*852b6486SRafael Ubal input = builder.create<arith::MaxUIOp>(loc, input, storageMin); 281*852b6486SRafael Ubal input = builder.create<arith::MinUIOp>(loc, input, storageMax); 282*852b6486SRafael Ubal } 283*852b6486SRafael Ubal return input; 284*852b6486SRafael Ubal } 285*852b6486SRafael Ubal 286*852b6486SRafael Ubal // Emit op 'arith.fptosi' or 'arith.fptoui'. 287*852b6486SRafael Ubal Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, 288*852b6486SRafael Ubal Type resultType, bool isSigned) { 289*852b6486SRafael Ubal if (isSigned) 290*852b6486SRafael Ubal return builder.create<arith::FPToSIOp>(loc, resultType, input); 291*852b6486SRafael Ubal return builder.create<arith::FPToUIOp>(loc, resultType, input); 292*852b6486SRafael Ubal } 293*852b6486SRafael Ubal 294*852b6486SRafael Ubal // Emit op 'arith.sitofp' or 'arith.uitofp'. 295*852b6486SRafael Ubal Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, 296*852b6486SRafael Ubal Type resultType, bool isSigned) { 297*852b6486SRafael Ubal if (isSigned) 298*852b6486SRafael Ubal return builder.create<arith::SIToFPOp>(loc, resultType, input); 299*852b6486SRafael Ubal return builder.create<arith::UIToFPOp>(loc, resultType, input); 300*852b6486SRafael Ubal } 301*852b6486SRafael Ubal 302*852b6486SRafael Ubal // Quantize a scalar or ranked tensor value. The stored value is clamped using 303*852b6486SRafael Ubal // the storage bounds encoded in the given quantized type. 304*852b6486SRafael Ubal // 305*852b6486SRafael Ubal // See function 'convertRanked()' below for a description of the arguments. 306*852b6486SRafael Ubal Value quantizeValue(OpBuilder &builder, Location loc, Value input, 307*852b6486SRafael Ubal ArrayRef<OpFoldResult> inputShape, Value scale, 308*852b6486SRafael Ubal Value zeroPoint, QuantizedType quantizedType) { 309*852b6486SRafael Ubal // Convert scale to tensor if necessary 310*852b6486SRafael Ubal auto inputType = input.getType(); 311*852b6486SRafael Ubal scale = getScalarOrTensorConstant( 312*852b6486SRafael Ubal builder, loc, scale, inputType, inputShape); 313*852b6486SRafael Ubal 314*852b6486SRafael Ubal // Scale input 315*852b6486SRafael Ubal auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale); 316*852b6486SRafael Ubal 317*852b6486SRafael Ubal // Skip unnecessary computations if no zero point is given 318*852b6486SRafael Ubal Value storedValueFloat = scaledValue; 319*852b6486SRafael Ubal if (!matchPattern(zeroPoint, m_Zero())) { 320*852b6486SRafael Ubal // Convert zero point to tensor if necessary 321*852b6486SRafael Ubal zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, 322*852b6486SRafael Ubal inputShape); 323*852b6486SRafael Ubal 324*852b6486SRafael Ubal // Convert zero point from storage to expressed type 325*852b6486SRafael Ubal zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, 326*852b6486SRafael Ubal scale.getType(), 327*852b6486SRafael Ubal quantizedType.isSigned()); 328*852b6486SRafael Ubal 329*852b6486SRafael Ubal // Add zero point to stored value 330*852b6486SRafael Ubal storedValueFloat = 331*852b6486SRafael Ubal builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint); 332*852b6486SRafael Ubal } 333*852b6486SRafael Ubal 334*852b6486SRafael Ubal // Convert stored value to storage type 335*852b6486SRafael Ubal auto storageScalarOrTensorType = 336*852b6486SRafael Ubal getScalarOrTensorType(quantizedType.getStorageType(), inputType); 337*852b6486SRafael Ubal auto storedValueInt = convertFloatToInteger( 338*852b6486SRafael Ubal builder, loc, storedValueFloat, storageScalarOrTensorType, 339*852b6486SRafael Ubal quantizedType.isSigned()); 340*852b6486SRafael Ubal 341*852b6486SRafael Ubal // Clamp stored value it if the storage type is bound 342*852b6486SRafael Ubal auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt, 343*852b6486SRafael Ubal inputShape, quantizedType); 344*852b6486SRafael Ubal return storedValueClamped; 345*852b6486SRafael Ubal } 346*852b6486SRafael Ubal 347*852b6486SRafael Ubal // Dequantize a scalar or ranked tensor input. 348*852b6486SRafael Ubal // 349*852b6486SRafael Ubal // See function 'convertRanked()' below for a description of the arguments. 350*852b6486SRafael Ubal Value dequantizeValue(OpBuilder &builder, Location loc, Value input, 351*852b6486SRafael Ubal ArrayRef<OpFoldResult> inputShape, Value scale, 352*852b6486SRafael Ubal Value zeroPoint, QuantizedType quantizedType) { 353*852b6486SRafael Ubal // Convert scale to tensor if necessary 354*852b6486SRafael Ubal auto inputType = input.getType(); 355*852b6486SRafael Ubal scale = getScalarOrTensorConstant( 356*852b6486SRafael Ubal builder, loc, scale, inputType, inputShape); 357*852b6486SRafael Ubal 358*852b6486SRafael Ubal // Convert stored value to float 359*852b6486SRafael Ubal auto result = convertIntegerToFloat( 360*852b6486SRafael Ubal builder, loc, input, scale.getType(), quantizedType.isSigned()); 361*852b6486SRafael Ubal 362*852b6486SRafael Ubal // Skip unnecessary computations if no zero point is given 363*852b6486SRafael Ubal if (!matchPattern(zeroPoint, m_Zero())) { 364*852b6486SRafael Ubal // Convert zero point to tensor if necessary 365*852b6486SRafael Ubal zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, 366*852b6486SRafael Ubal inputShape); 367*852b6486SRafael Ubal 368*852b6486SRafael Ubal // Convert zero point from storage to expressed type 369*852b6486SRafael Ubal zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, 370*852b6486SRafael Ubal scale.getType(), 371*852b6486SRafael Ubal quantizedType.isSigned()); 372*852b6486SRafael Ubal 373*852b6486SRafael Ubal // Subtract zero point to stored value 374*852b6486SRafael Ubal result = builder.create<arith::SubFOp>(loc, result, zeroPoint); 375*852b6486SRafael Ubal } 376*852b6486SRafael Ubal 377*852b6486SRafael Ubal // Multiply by scale 378*852b6486SRafael Ubal result = builder.create<arith::MulFOp>(loc, result, scale); 379*852b6486SRafael Ubal return result; 380*852b6486SRafael Ubal } 381*852b6486SRafael Ubal 382*852b6486SRafael Ubal // Convert a scalar or ranked tensor input with the given scale and zero point 383*852b6486SRafael Ubal // values. 384*852b6486SRafael Ubal // 385*852b6486SRafael Ubal // - input 386*852b6486SRafael Ubal // Scalar or ranked tensor value. 387*852b6486SRafael Ubal // 388*852b6486SRafael Ubal // - inputShape 389*852b6486SRafael Ubal // If 'input' is a tensor, combination or attributes/values representing its 390*852b6486SRafael Ubal // static/dynamic dimensions. If 'input' is a scalar, empty list. 391*852b6486SRafael Ubal // 392*852b6486SRafael Ubal // - scale 393*852b6486SRafael Ubal // Scale as a floating-point scalar value. 394*852b6486SRafael Ubal // 395*852b6486SRafael Ubal // - zeroPoint 396*852b6486SRafael Ubal // Zero point as an integer scalar value. 397*852b6486SRafael Ubal // 398*852b6486SRafael Ubal // - quantizedType 399*852b6486SRafael Ubal // Scalar quantized type of the result ('quant.qcast') or of the input 400*852b6486SRafael Ubal // ('quant.dcast'). 401*852b6486SRafael Ubal // 402*852b6486SRafael Ubal Value convertRanked(OpBuilder &builder, Location loc, Operation *op, 403*852b6486SRafael Ubal Value input, ArrayRef<OpFoldResult> inputShape, Value scale, 404*852b6486SRafael Ubal Value zeroPoint, QuantizedType quantizedType) { 405*852b6486SRafael Ubal if (isa<QuantizeCastOp>(op)) 406*852b6486SRafael Ubal return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint, 407*852b6486SRafael Ubal quantizedType); 408*852b6486SRafael Ubal if (isa<DequantizeCastOp>(op)) 409*852b6486SRafael Ubal return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint, 410*852b6486SRafael Ubal quantizedType); 411*852b6486SRafael Ubal llvm_unreachable("unexpected quant op"); 412*852b6486SRafael Ubal } 413*852b6486SRafael Ubal 414*852b6486SRafael Ubal // Convert an operation using per-layer quantization with a scalar or ranked 415*852b6486SRafael Ubal // tensor input. 416*852b6486SRafael Ubal // 417*852b6486SRafael Ubal // - op 418*852b6486SRafael Ubal // 'quant.dcast' or 'quant.qcast' op. 419*852b6486SRafael Ubal // 420*852b6486SRafael Ubal // - input 421*852b6486SRafael Ubal // Scalar or ranked tensor. 422*852b6486SRafael Ubal // 423*852b6486SRafael Ubal // - quantizedType 424*852b6486SRafael Ubal // Per-layer quantized type. 425*852b6486SRafael Ubal // 426*852b6486SRafael Ubal Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, 427*852b6486SRafael Ubal Value input, UniformQuantizedType quantizedType) { 428*852b6486SRafael Ubal // Create scale and zero point constants 429*852b6486SRafael Ubal auto expressedType = quantizedType.getExpressedType(); 430*852b6486SRafael Ubal auto storageType = quantizedType.getStorageType(); 431*852b6486SRafael Ubal auto scaleAttr = 432*852b6486SRafael Ubal builder.getFloatAttr(expressedType, quantizedType.getScale()); 433*852b6486SRafael Ubal auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr); 434*852b6486SRafael Ubal auto zeroPointAttr = 435*852b6486SRafael Ubal builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); 436*852b6486SRafael Ubal auto zeroPoint = 437*852b6486SRafael Ubal builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr); 438*852b6486SRafael Ubal 439*852b6486SRafael Ubal auto inputShape = getScalarOrTensorShape(builder, loc, input); 440*852b6486SRafael Ubal return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint, 441*852b6486SRafael Ubal quantizedType); 442*852b6486SRafael Ubal } 443*852b6486SRafael Ubal 444*852b6486SRafael Ubal // Convert an operation using per-layer quantization. 445*852b6486SRafael Ubal // 446*852b6486SRafael Ubal // - op 447*852b6486SRafael Ubal // 'quant.dcast' or 'quant.qcast' op. 448*852b6486SRafael Ubal // 449*852b6486SRafael Ubal // - input 450*852b6486SRafael Ubal // Scalar, ranked tensor, or unranked tensor. 451*852b6486SRafael Ubal // 452*852b6486SRafael Ubal // - quantizedType 453*852b6486SRafael Ubal // Per-layer quantized type. 454*852b6486SRafael Ubal // 455*852b6486SRafael Ubal Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op, 456*852b6486SRafael Ubal Value input, UniformQuantizedType quantizedType) { 457*852b6486SRafael Ubal // Flatten input if unranked 458*852b6486SRafael Ubal bool isUnranked = isa<UnrankedTensorType>(input.getType()); 459*852b6486SRafael Ubal Value inputShape; 460*852b6486SRafael Ubal if (isUnranked) 461*852b6486SRafael Ubal std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input); 462*852b6486SRafael Ubal 463*852b6486SRafael Ubal // Process ranked tensor 464*852b6486SRafael Ubal auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType); 465*852b6486SRafael Ubal 466*852b6486SRafael Ubal // Restore original shape if unranked 467*852b6486SRafael Ubal if (isUnranked) 468*852b6486SRafael Ubal result = restoreUnrankedTensorShape(builder, loc, result, inputShape); 469*852b6486SRafael Ubal 470*852b6486SRafael Ubal return result; 471*852b6486SRafael Ubal } 472*852b6486SRafael Ubal 473*852b6486SRafael Ubal // Convert an operation using per-channel quantization and a scalar or ranked 474*852b6486SRafael Ubal // tensor as an input. 475*852b6486SRafael Ubal // 476*852b6486SRafael Ubal // - op 477*852b6486SRafael Ubal // 'quant.dcast' or 'quant.qcast' op. 478*852b6486SRafael Ubal // 479*852b6486SRafael Ubal // - input 480*852b6486SRafael Ubal // Scalar or ranked tensor. 481*852b6486SRafael Ubal // 482*852b6486SRafael Ubal // - quantizedType 483*852b6486SRafael Ubal // Per-channel quantized type. 484*852b6486SRafael Ubal // 485*852b6486SRafael Ubal Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, 486*852b6486SRafael Ubal Value input, 487*852b6486SRafael Ubal UniformQuantizedPerAxisType quantizedType, 488*852b6486SRafael Ubal int64_t channelAxis) { 489*852b6486SRafael Ubal auto *context = builder.getContext(); 490*852b6486SRafael Ubal 491*852b6486SRafael Ubal auto inputType = cast<RankedTensorType>(input.getType()); 492*852b6486SRafael Ubal auto inputRank = inputType.getRank(); 493*852b6486SRafael Ubal 494*852b6486SRafael Ubal auto scales = materializePerChannelScales(builder, loc, quantizedType); 495*852b6486SRafael Ubal auto zeroPoints = 496*852b6486SRafael Ubal materializePerChannelZeroPoints(builder, loc, quantizedType); 497*852b6486SRafael Ubal 498*852b6486SRafael Ubal auto elementType = isa<FloatType>(inputType.getElementType()) 499*852b6486SRafael Ubal ? quantizedType.getStorageType() 500*852b6486SRafael Ubal : quantizedType.getExpressedType(); 501*852b6486SRafael Ubal auto initShape = tensor::getMixedSizes(builder, loc, input); 502*852b6486SRafael Ubal Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType); 503*852b6486SRafael Ubal 504*852b6486SRafael Ubal SmallVector<utils::IteratorType> iteratorTypes( 505*852b6486SRafael Ubal inputRank, utils::IteratorType::parallel); 506*852b6486SRafael Ubal auto channelAxisAffineMap = AffineMap::get( 507*852b6486SRafael Ubal inputRank, 0, builder.getAffineDimExpr(channelAxis), context); 508*852b6486SRafael Ubal SmallVector<AffineMap> indexingMaps{ 509*852b6486SRafael Ubal builder.getMultiDimIdentityMap(inputRank), 510*852b6486SRafael Ubal channelAxisAffineMap, 511*852b6486SRafael Ubal channelAxisAffineMap, 512*852b6486SRafael Ubal builder.getMultiDimIdentityMap(inputRank) 513*852b6486SRafael Ubal }; 514*852b6486SRafael Ubal auto result = builder.create<linalg::GenericOp>( 515*852b6486SRafael Ubal loc, 516*852b6486SRafael Ubal init.getType(), // resultType 517*852b6486SRafael Ubal ValueRange{input, scales, zeroPoints}, // inputs 518*852b6486SRafael Ubal ValueRange{init}, // outputs 519*852b6486SRafael Ubal indexingMaps, 520*852b6486SRafael Ubal iteratorTypes, 521*852b6486SRafael Ubal [&](OpBuilder& builder, Location loc, ValueRange args) { 522*852b6486SRafael Ubal assert(args.size() == 4); 523*852b6486SRafael Ubal auto input = args[0]; 524*852b6486SRafael Ubal auto scale = args[1]; 525*852b6486SRafael Ubal auto zeroPoint = args[2]; 526*852b6486SRafael Ubal 527*852b6486SRafael Ubal auto result = convertRanked(builder, loc, op, input, {}, scale, 528*852b6486SRafael Ubal zeroPoint, quantizedType); 529*852b6486SRafael Ubal 530*852b6486SRafael Ubal builder.create<linalg::YieldOp>(loc, result); 531*852b6486SRafael Ubal }) 532*852b6486SRafael Ubal .getResult(0); 533*852b6486SRafael Ubal 534*852b6486SRafael Ubal return result; 535*852b6486SRafael Ubal } 536*852b6486SRafael Ubal 537*852b6486SRafael Ubal // Convert an operation using per-channel quantization. 538*852b6486SRafael Ubal // 539*852b6486SRafael Ubal // - op 540*852b6486SRafael Ubal // 'quant.dcast' or 'quant.qcast' op. 541*852b6486SRafael Ubal // 542*852b6486SRafael Ubal // - input 543*852b6486SRafael Ubal // Scalar, ranked tensor, or unranked tensor. 544*852b6486SRafael Ubal // 545*852b6486SRafael Ubal // - quantizedType 546*852b6486SRafael Ubal // Per-channel quantized type. 547*852b6486SRafael Ubal // 548*852b6486SRafael Ubal Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, 549*852b6486SRafael Ubal Value input, 550*852b6486SRafael Ubal UniformQuantizedPerAxisType quantizedType) { 551*852b6486SRafael Ubal // Flatten unranked tensor into a 3D ranked tensor if necessary 552*852b6486SRafael Ubal bool isUnranked = isa<UnrankedTensorType>(input.getType()); 553*852b6486SRafael Ubal int64_t channelAxis = quantizedType.getQuantizedDimension(); 554*852b6486SRafael Ubal int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); 555*852b6486SRafael Ubal Value inputShape; 556*852b6486SRafael Ubal if (isUnranked) { 557*852b6486SRafael Ubal std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( 558*852b6486SRafael Ubal builder, loc, input, channelAxis, channelAxisSize); 559*852b6486SRafael Ubal channelAxis = 1; 560*852b6486SRafael Ubal } 561*852b6486SRafael Ubal 562*852b6486SRafael Ubal // Work on a ranked tensor 563*852b6486SRafael Ubal auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType, 564*852b6486SRafael Ubal channelAxis); 565*852b6486SRafael Ubal 566*852b6486SRafael Ubal // Restore original tensor shape if unranked 567*852b6486SRafael Ubal if (isUnranked) 568*852b6486SRafael Ubal result = restoreUnrankedTensorShape(builder, loc, result, inputShape); 569*852b6486SRafael Ubal 570*852b6486SRafael Ubal return result; 571*852b6486SRafael Ubal } 572*852b6486SRafael Ubal 573*852b6486SRafael Ubal // Convert a quantization operation. 574*852b6486SRafael Ubal // 575*852b6486SRafael Ubal // - op 576*852b6486SRafael Ubal // 'quant.dcast' or 'quant.qcast' op. 577*852b6486SRafael Ubal // 578*852b6486SRafael Ubal // - input 579*852b6486SRafael Ubal // Scalar, ranked tensor, or unranked tensor. The element type matches 580*852b6486SRafael Ubal // the storage type (quant.dcast) or expressed type (quant.qcast) of 581*852b6486SRafael Ubal // 'quantizedType'. 582*852b6486SRafael Ubal // 583*852b6486SRafael Ubal // - quantizedType 584*852b6486SRafael Ubal // Per-layer or per-channel quantized type. 585*852b6486SRafael Ubal // 586*852b6486SRafael Ubal Value convertQuantized(OpBuilder &builder, Location loc, Operation *op, 587*852b6486SRafael Ubal Value input, Type quantizedType) { 588*852b6486SRafael Ubal if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType)) 589*852b6486SRafael Ubal return convertPerLayer(builder, loc, op, input, uniformQuantizedType); 590*852b6486SRafael Ubal 591*852b6486SRafael Ubal if (auto uniformQuantizedPerAxisType = 592*852b6486SRafael Ubal dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) 593*852b6486SRafael Ubal return convertPerChannel(builder, loc, op, input, 594*852b6486SRafael Ubal uniformQuantizedPerAxisType); 595*852b6486SRafael Ubal 596*852b6486SRafael Ubal llvm_unreachable("unexpected quantized type"); 597*852b6486SRafael Ubal } 598*852b6486SRafael Ubal 599*852b6486SRafael Ubal // Lowering pattern for 'quant.dcast' 600*852b6486SRafael Ubal struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> { 601*852b6486SRafael Ubal using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern; 602*852b6486SRafael Ubal 603*852b6486SRafael Ubal LogicalResult 604*852b6486SRafael Ubal matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor, 605*852b6486SRafael Ubal ConversionPatternRewriter &rewriter) const override { 606*852b6486SRafael Ubal auto loc = op.getLoc(); 607*852b6486SRafael Ubal auto input = op.getInput(); 608*852b6486SRafael Ubal auto quantizedType = 609*852b6486SRafael Ubal cast<QuantizedType>(getScalarType(op.getInput().getType())); 610*852b6486SRafael Ubal 611*852b6486SRafael Ubal // Convert quantized input to storage type 612*852b6486SRafael Ubal auto storageScalarOrTensorType = 613*852b6486SRafael Ubal getScalarOrTensorType(quantizedType.getStorageType(), input.getType()); 614*852b6486SRafael Ubal input = rewriter.create<quant::StorageCastOp>( 615*852b6486SRafael Ubal loc, storageScalarOrTensorType, input); 616*852b6486SRafael Ubal 617*852b6486SRafael Ubal auto result = convertQuantized(rewriter, loc, op, input, quantizedType); 618*852b6486SRafael Ubal 619*852b6486SRafael Ubal rewriter.replaceOp(op, result); 620*852b6486SRafael Ubal return success(); 621*852b6486SRafael Ubal } 622*852b6486SRafael Ubal }; 623*852b6486SRafael Ubal 624*852b6486SRafael Ubal // Lowering pattern for 'quant.qcast' 625*852b6486SRafael Ubal struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> { 626*852b6486SRafael Ubal using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern; 627*852b6486SRafael Ubal 628*852b6486SRafael Ubal LogicalResult 629*852b6486SRafael Ubal matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor, 630*852b6486SRafael Ubal ConversionPatternRewriter &rewriter) const override { 631*852b6486SRafael Ubal auto loc = op.getLoc(); 632*852b6486SRafael Ubal auto input = op.getInput(); 633*852b6486SRafael Ubal auto quantizedType = getScalarType(op.getResult().getType()); 634*852b6486SRafael Ubal 635*852b6486SRafael Ubal // Flatten unranked tensor input 636*852b6486SRafael Ubal auto result = convertQuantized(rewriter, loc, op, input, quantizedType); 637*852b6486SRafael Ubal 638*852b6486SRafael Ubal // Cast stored value to result quantized value 639*852b6486SRafael Ubal rewriter.replaceOpWithNewOp<quant::StorageCastOp>( 640*852b6486SRafael Ubal op, op.getResult().getType(), result); 641*852b6486SRafael Ubal return success(); 642*852b6486SRafael Ubal } 643*852b6486SRafael Ubal }; 644*852b6486SRafael Ubal 645*852b6486SRafael Ubal struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> { 646*852b6486SRafael Ubal void runOnOperation() override { 647*852b6486SRafael Ubal RewritePatternSet patterns(&getContext()); 648*852b6486SRafael Ubal populateLowerQuantOpsPatterns(patterns); 649*852b6486SRafael Ubal 650*852b6486SRafael Ubal ConversionTarget target(getContext()); 651*852b6486SRafael Ubal target.addLegalOp<quant::StorageCastOp>(); 652*852b6486SRafael Ubal target.addIllegalDialect<quant::QuantDialect>(); 653*852b6486SRafael Ubal target.addLegalDialect< 654*852b6486SRafael Ubal arith::ArithDialect, 655*852b6486SRafael Ubal linalg::LinalgDialect, 656*852b6486SRafael Ubal shape::ShapeDialect, 657*852b6486SRafael Ubal tensor::TensorDialect 658*852b6486SRafael Ubal >(); 659*852b6486SRafael Ubal 660*852b6486SRafael Ubal if (failed(applyPartialConversion(getOperation(), target, 661*852b6486SRafael Ubal std::move(patterns)))) 662*852b6486SRafael Ubal signalPassFailure(); 663*852b6486SRafael Ubal } 664*852b6486SRafael Ubal }; 665*852b6486SRafael Ubal 666*852b6486SRafael Ubal } // namespace 667*852b6486SRafael Ubal 668*852b6486SRafael Ubal void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) { 669*852b6486SRafael Ubal patterns.add< 670*852b6486SRafael Ubal DequantizeCastOpConversion, 671*852b6486SRafael Ubal QuantizeCastOpConversion 672*852b6486SRafael Ubal >(patterns.getContext()); 673*852b6486SRafael Ubal } 674*852b6486SRafael Ubal 675*852b6486SRafael Ubal } // namespace quant 676*852b6486SRafael Ubal } // namespace mlir 677