xref: /llvm-project/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp (revision 852b6486246141e44cc9f126f542a2ae0d73b3d6)
1*852b6486SRafael Ubal //===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
2*852b6486SRafael Ubal //
3*852b6486SRafael Ubal // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*852b6486SRafael Ubal // See https://llvm.org/LICENSE.txt for license information.
5*852b6486SRafael Ubal // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*852b6486SRafael Ubal //
7*852b6486SRafael Ubal //===----------------------------------------------------------------------===//
8*852b6486SRafael Ubal //
9*852b6486SRafael Ubal // Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
10*852b6486SRafael Ubal //
11*852b6486SRafael Ubal //===----------------------------------------------------------------------===//
12*852b6486SRafael Ubal 
13*852b6486SRafael Ubal #include "mlir/Dialect/Arith/IR/Arith.h"
14*852b6486SRafael Ubal #include "mlir/Dialect/Func/IR/FuncOps.h"
15*852b6486SRafael Ubal #include "mlir/Dialect/Linalg/IR/Linalg.h"
16*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/Quant.h"
17*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/QuantTypes.h"
18*852b6486SRafael Ubal #include "mlir/Dialect/Quant/Transforms/Passes.h"
19*852b6486SRafael Ubal #include "mlir/Dialect/Shape/IR/Shape.h"
20*852b6486SRafael Ubal #include "mlir/Dialect/Tensor/IR/Tensor.h"
21*852b6486SRafael Ubal #include "mlir/IR/Matchers.h"
22*852b6486SRafael Ubal #include "mlir/IR/PatternMatch.h"
23*852b6486SRafael Ubal #include "mlir/Transforms/DialectConversion.h"
24*852b6486SRafael Ubal 
25*852b6486SRafael Ubal namespace mlir {
26*852b6486SRafael Ubal namespace quant {
27*852b6486SRafael Ubal 
28*852b6486SRafael Ubal #define GEN_PASS_DEF_LOWERQUANTOPS
29*852b6486SRafael Ubal #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
30*852b6486SRafael Ubal 
31*852b6486SRafael Ubal namespace {
32*852b6486SRafael Ubal 
33*852b6486SRafael Ubal // If 'inputType' is a tensor, return its element type. If it is a scalar,
34*852b6486SRafael Ubal // return it as is.
35*852b6486SRafael Ubal Type getScalarType(Type inputType) {
36*852b6486SRafael Ubal   if (auto tensorType = dyn_cast<TensorType>(inputType))
37*852b6486SRafael Ubal     return tensorType.getElementType();
38*852b6486SRafael Ubal   return inputType;
39*852b6486SRafael Ubal }
40*852b6486SRafael Ubal 
41*852b6486SRafael Ubal // Return the shape of an input value as a list of attributes (static dimensions)
42*852b6486SRafael Ubal // and values (dynamic dimensions). If 'input' is a scalar, an empty list is
43*852b6486SRafael Ubal // returned. If 'input' is a tensor, its shape is returned.
44*852b6486SRafael Ubal SmallVector<OpFoldResult>
45*852b6486SRafael Ubal getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
46*852b6486SRafael Ubal   if (isa<TensorType>(input.getType()))
47*852b6486SRafael Ubal     return tensor::getMixedSizes(builder, loc, input);
48*852b6486SRafael Ubal   return {};
49*852b6486SRafael Ubal }
50*852b6486SRafael Ubal 
51*852b6486SRafael Ubal // If 'referenceType' is a scalar, return 'elementType' as is. If
52*852b6486SRafael Ubal // 'referenceType' is a tensor, return another tensor with the same shape and
53*852b6486SRafael Ubal // elements of type 'elementType'.
54*852b6486SRafael Ubal Type getScalarOrTensorType(Type elementType, Type referenceType) {
55*852b6486SRafael Ubal   if (auto tensorType = dyn_cast<TensorType>(referenceType))
56*852b6486SRafael Ubal     return tensorType.clone(elementType);
57*852b6486SRafael Ubal   return elementType;
58*852b6486SRafael Ubal }
59*852b6486SRafael Ubal 
60*852b6486SRafael Ubal // Return a constant with the given value. If 'referenceType' is a tensor, a
61*852b6486SRafael Ubal // tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
62*852b6486SRafael Ubal // scalar, 'referenceShape' is ignored and a scalar constant is returned.
63*852b6486SRafael Ubal Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
64*852b6486SRafael Ubal                                 Type referenceType,
65*852b6486SRafael Ubal                                 ArrayRef<OpFoldResult> referenceShape) {
66*852b6486SRafael Ubal   // If the result type is a scalar, return the unmodified scalar constant.
67*852b6486SRafael Ubal   auto tensorType = dyn_cast<TensorType>(referenceType);
68*852b6486SRafael Ubal   if (!tensorType) {
69*852b6486SRafael Ubal     assert(referenceShape.empty());
70*852b6486SRafael Ubal     return scalar;
71*852b6486SRafael Ubal   }
72*852b6486SRafael Ubal 
73*852b6486SRafael Ubal   // Create tensor splat
74*852b6486SRafael Ubal   auto tensorConstant =
75*852b6486SRafael Ubal       builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
76*852b6486SRafael Ubal   return tensorConstant;
77*852b6486SRafael Ubal }
78*852b6486SRafael Ubal 
79*852b6486SRafael Ubal // Reshape an unranked tensor into a 1D ranked tensor.
80*852b6486SRafael Ubal //
81*852b6486SRafael Ubal // - input
82*852b6486SRafael Ubal //   Unranked tensor.
83*852b6486SRafael Ubal //
84*852b6486SRafael Ubal // Return values:
85*852b6486SRafael Ubal //
86*852b6486SRafael Ubal // - flatInput
87*852b6486SRafael Ubal //   1D ranked, dynamically shaped tensor.
88*852b6486SRafael Ubal //
89*852b6486SRafael Ubal // - inputShape
90*852b6486SRafael Ubal //   1D extent tensor containing the shape of the original unranked input.
91*852b6486SRafael Ubal //
92*852b6486SRafael Ubal std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
93*852b6486SRafael Ubal                                               Value input) {
94*852b6486SRafael Ubal   // Get unranked input shape and total size
95*852b6486SRafael Ubal   auto *context = builder.getContext();
96*852b6486SRafael Ubal   auto shapeType = shape::getExtentTensorType(context);
97*852b6486SRafael Ubal   auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
98*852b6486SRafael Ubal   Value inputSize = builder.create<shape::NumElementsOp>(
99*852b6486SRafael Ubal       loc, builder.getIndexType(), inputShape);
100*852b6486SRafael Ubal 
101*852b6486SRafael Ubal   // Turn input size into 1D tensor
102*852b6486SRafael Ubal   auto flatShapeType = shape::getExtentTensorType(context, 1);
103*852b6486SRafael Ubal   auto flatInputShape = builder.create<tensor::FromElementsOp>(
104*852b6486SRafael Ubal       loc, flatShapeType, inputSize);
105*852b6486SRafael Ubal 
106*852b6486SRafael Ubal   // Reshape input tensor into 1D
107*852b6486SRafael Ubal   auto inputType = cast<UnrankedTensorType>(input.getType());
108*852b6486SRafael Ubal   auto elementType = inputType.getElementType();
109*852b6486SRafael Ubal   auto flatInputType =
110*852b6486SRafael Ubal       RankedTensorType::get({ShapedType::kDynamic}, elementType);
111*852b6486SRafael Ubal   auto flatInput = builder.create<tensor::ReshapeOp>(
112*852b6486SRafael Ubal       loc, flatInputType, input, flatInputShape);
113*852b6486SRafael Ubal   return std::make_pair(flatInput, inputShape);
114*852b6486SRafael Ubal }
115*852b6486SRafael Ubal 
116*852b6486SRafael Ubal // Reshape an unranked tensor into a 3D ranked tensor where the central
117*852b6486SRafael Ubal // dimension of the result tensor corresponds to dimension 'axis' of the input
118*852b6486SRafael Ubal // tensor.
119*852b6486SRafael Ubal //
120*852b6486SRafael Ubal // - input
121*852b6486SRafael Ubal //   Unranked tensor.
122*852b6486SRafael Ubal //
123*852b6486SRafael Ubal // - axis
124*852b6486SRafael Ubal //   Index of the input dimension around which other input dimiensions will be
125*852b6486SRafael Ubal //   collapsed.
126*852b6486SRafael Ubal //
127*852b6486SRafael Ubal // - axisSize
128*852b6486SRafael Ubal //   Size of input dimension 'axis'.
129*852b6486SRafael Ubal //
130*852b6486SRafael Ubal // Return values:
131*852b6486SRafael Ubal //
132*852b6486SRafael Ubal // - flatInput
133*852b6486SRafael Ubal //   3D ranked tensor of shape [?, axisSize, ?].
134*852b6486SRafael Ubal //
135*852b6486SRafael Ubal // - inputShape
136*852b6486SRafael Ubal //   1D extent tensor containing the shape of the original unranked input.
137*852b6486SRafael Ubal //
138*852b6486SRafael Ubal std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
139*852b6486SRafael Ubal                                                         Location loc,
140*852b6486SRafael Ubal                                                         Value input,
141*852b6486SRafael Ubal                                                         int64_t axis,
142*852b6486SRafael Ubal                                                         int64_t axisSize) {
143*852b6486SRafael Ubal   // Get full tensor shape
144*852b6486SRafael Ubal   auto *context = builder.getContext();
145*852b6486SRafael Ubal   auto indexType = builder.getIndexType();
146*852b6486SRafael Ubal   auto shapeType = shape::getExtentTensorType(context);
147*852b6486SRafael Ubal   auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
148*852b6486SRafael Ubal 
149*852b6486SRafael Ubal   // Get shape and sizes on left and right of axis
150*852b6486SRafael Ubal   auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
151*852b6486SRafael Ubal   auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
152*852b6486SRafael Ubal   auto shapeLeft = builder.create<shape::SplitAtOp>(
153*852b6486SRafael Ubal       loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
154*852b6486SRafael Ubal       .getResult(0);
155*852b6486SRafael Ubal   auto sizeLeft = builder.create<shape::NumElementsOp>(
156*852b6486SRafael Ubal       loc, indexType, shapeLeft);
157*852b6486SRafael Ubal   auto shapeRight = builder.create<shape::SplitAtOp>(
158*852b6486SRafael Ubal       loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
159*852b6486SRafael Ubal       .getResult(1);
160*852b6486SRafael Ubal   auto sizeRight = builder.create<shape::NumElementsOp>(
161*852b6486SRafael Ubal       loc, indexType, shapeRight);
162*852b6486SRafael Ubal 
163*852b6486SRafael Ubal   // Compute flat input shape as a 3-element 1D tensor
164*852b6486SRafael Ubal   auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
165*852b6486SRafael Ubal   auto flatShapeType = shape::getExtentTensorType(context, 3);
166*852b6486SRafael Ubal   auto flatInputShape = builder.create<tensor::FromElementsOp>(
167*852b6486SRafael Ubal       loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});
168*852b6486SRafael Ubal 
169*852b6486SRafael Ubal   // Reshape input to 3D tensor
170*852b6486SRafael Ubal   auto inputType = cast<UnrankedTensorType>(input.getType());
171*852b6486SRafael Ubal   auto elementType = inputType.getElementType();
172*852b6486SRafael Ubal   auto flatInputType = RankedTensorType::get(
173*852b6486SRafael Ubal       {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
174*852b6486SRafael Ubal   auto flatInput = builder.create<tensor::ReshapeOp>(
175*852b6486SRafael Ubal       loc, flatInputType, input, flatInputShape);
176*852b6486SRafael Ubal 
177*852b6486SRafael Ubal   return std::make_pair(flatInput, inputShape);
178*852b6486SRafael Ubal }
179*852b6486SRafael Ubal 
180*852b6486SRafael Ubal // Reshape an input tensor into its original unranked shape.
181*852b6486SRafael Ubal //
182*852b6486SRafael Ubal // - input
183*852b6486SRafael Ubal //   Ranked tensor.
184*852b6486SRafael Ubal //
185*852b6486SRafael Ubal // - inputShape
186*852b6486SRafael Ubal //   1D extent tensor.
187*852b6486SRafael Ubal //
188*852b6486SRafael Ubal Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
189*852b6486SRafael Ubal                                  Value inputShape) {
190*852b6486SRafael Ubal   auto inputType = cast<RankedTensorType>(input.getType());
191*852b6486SRafael Ubal   auto elementType = inputType.getElementType();
192*852b6486SRafael Ubal   auto unrankedType = UnrankedTensorType::get(elementType);
193*852b6486SRafael Ubal   return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
194*852b6486SRafael Ubal }
195*852b6486SRafael Ubal 
196*852b6486SRafael Ubal // Create a tensor constant containing all scales in a per-channel quantized
197*852b6486SRafael Ubal // type. Example:
198*852b6486SRafael Ubal //
199*852b6486SRafael Ubal //   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
200*852b6486SRafael Ubal //
201*852b6486SRafael Ubal // produces
202*852b6486SRafael Ubal //
203*852b6486SRafael Ubal //   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
204*852b6486SRafael Ubal //
205*852b6486SRafael Ubal Value materializePerChannelScales(OpBuilder &builder, Location loc,
206*852b6486SRafael Ubal                                   UniformQuantizedPerAxisType quantizedType) {
207*852b6486SRafael Ubal   auto scales = quantizedType.getScales();
208*852b6486SRafael Ubal   auto expressedType = quantizedType.getExpressedType();
209*852b6486SRafael Ubal   auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
210*852b6486SRafael Ubal     return builder.getFloatAttr(expressedType, scale);
211*852b6486SRafael Ubal   });
212*852b6486SRafael Ubal   auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType);
213*852b6486SRafael Ubal   auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
214*852b6486SRafael Ubal   return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
215*852b6486SRafael Ubal }
216*852b6486SRafael Ubal 
217*852b6486SRafael Ubal // Create a tensor constant containing all zero points in a per-channel
218*852b6486SRafael Ubal // quantized type. Example:
219*852b6486SRafael Ubal //
220*852b6486SRafael Ubal //   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
221*852b6486SRafael Ubal //
222*852b6486SRafael Ubal // produces
223*852b6486SRafael Ubal //
224*852b6486SRafael Ubal //   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
225*852b6486SRafael Ubal //
226*852b6486SRafael Ubal Value materializePerChannelZeroPoints(
227*852b6486SRafael Ubal     OpBuilder &builder, Location loc,
228*852b6486SRafael Ubal     UniformQuantizedPerAxisType quantizedType) {
229*852b6486SRafael Ubal   auto zeroPoints = quantizedType.getZeroPoints();
230*852b6486SRafael Ubal   auto storageType = quantizedType.getStorageType();
231*852b6486SRafael Ubal   auto zeroPointAttrs = llvm::map_to_vector(
232*852b6486SRafael Ubal       zeroPoints,
233*852b6486SRafael Ubal       [&](int64_t zeroPoint) -> Attribute {
234*852b6486SRafael Ubal         return builder.getIntegerAttr(storageType, zeroPoint);
235*852b6486SRafael Ubal       });
236*852b6486SRafael Ubal   auto tensorType =
237*852b6486SRafael Ubal       RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
238*852b6486SRafael Ubal   auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
239*852b6486SRafael Ubal   return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
240*852b6486SRafael Ubal }
241*852b6486SRafael Ubal 
242*852b6486SRafael Ubal // Clamp the given scalar or tensor input using the storage bounds encoded in
243*852b6486SRafael Ubal // the given quantized type, if present.
244*852b6486SRafael Ubal //
245*852b6486SRafael Ubal // - input
246*852b6486SRafael Ubal //   Scalar or ranked tensor input. The element type must match the storage type
247*852b6486SRafael Ubal //   of 'quantizedType'.
248*852b6486SRafael Ubal //
249*852b6486SRafael Ubal // - inputShape
250*852b6486SRafael Ubal //   If 'input' is a tensor, combination of attributes/values representing its
251*852b6486SRafael Ubal //   static/dynamic dimensions. If 'input' is a scalar, empty list.
252*852b6486SRafael Ubal //
253*852b6486SRafael Ubal // - quantizedType
254*852b6486SRafael Ubal //   Per-axis or per-channel quantized type.
255*852b6486SRafael Ubal Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
256*852b6486SRafael Ubal                           ArrayRef<OpFoldResult> inputShape,
257*852b6486SRafael Ubal                           QuantizedType quantizedType) {
258*852b6486SRafael Ubal   // If quantized type does not narrow down the storage type range, there is
259*852b6486SRafael Ubal   // nothing to do.
260*852b6486SRafael Ubal   if (!quantizedType.hasStorageTypeBounds())
261*852b6486SRafael Ubal     return input;
262*852b6486SRafael Ubal 
263*852b6486SRafael Ubal   // Materialize bounds
264*852b6486SRafael Ubal   auto inputType = input.getType();
265*852b6486SRafael Ubal   auto storageType = quantizedType.getStorageType();
266*852b6486SRafael Ubal   auto storageMinScalar = builder.create<arith::ConstantIntOp>(
267*852b6486SRafael Ubal       loc, quantizedType.getStorageTypeMin(), storageType);
268*852b6486SRafael Ubal   auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
269*852b6486SRafael Ubal       loc, quantizedType.getStorageTypeMax(), storageType);
270*852b6486SRafael Ubal   auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
271*852b6486SRafael Ubal                                               inputType, inputShape);
272*852b6486SRafael Ubal   auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
273*852b6486SRafael Ubal                                               inputType, inputShape);
274*852b6486SRafael Ubal 
275*852b6486SRafael Ubal   // Clamp
276*852b6486SRafael Ubal   if (quantizedType.isSigned()) {
277*852b6486SRafael Ubal     input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
278*852b6486SRafael Ubal     input = builder.create<arith::MinSIOp>(loc, input, storageMax);
279*852b6486SRafael Ubal   } else {
280*852b6486SRafael Ubal     input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
281*852b6486SRafael Ubal     input = builder.create<arith::MinUIOp>(loc, input, storageMax);
282*852b6486SRafael Ubal   }
283*852b6486SRafael Ubal   return input;
284*852b6486SRafael Ubal }
285*852b6486SRafael Ubal 
286*852b6486SRafael Ubal // Emit op 'arith.fptosi' or 'arith.fptoui'.
287*852b6486SRafael Ubal Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
288*852b6486SRafael Ubal                             Type resultType, bool isSigned) {
289*852b6486SRafael Ubal   if (isSigned)
290*852b6486SRafael Ubal     return builder.create<arith::FPToSIOp>(loc, resultType, input);
291*852b6486SRafael Ubal   return builder.create<arith::FPToUIOp>(loc, resultType, input);
292*852b6486SRafael Ubal }
293*852b6486SRafael Ubal 
294*852b6486SRafael Ubal // Emit op 'arith.sitofp' or 'arith.uitofp'.
295*852b6486SRafael Ubal Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
296*852b6486SRafael Ubal                             Type resultType, bool isSigned) {
297*852b6486SRafael Ubal   if (isSigned)
298*852b6486SRafael Ubal     return builder.create<arith::SIToFPOp>(loc, resultType, input);
299*852b6486SRafael Ubal   return builder.create<arith::UIToFPOp>(loc, resultType, input);
300*852b6486SRafael Ubal }
301*852b6486SRafael Ubal 
302*852b6486SRafael Ubal // Quantize a scalar or ranked tensor value. The stored value is clamped using
303*852b6486SRafael Ubal // the storage bounds encoded in the given quantized type.
304*852b6486SRafael Ubal //
305*852b6486SRafael Ubal // See function 'convertRanked()' below for a description of the arguments.
306*852b6486SRafael Ubal Value quantizeValue(OpBuilder &builder, Location loc, Value input,
307*852b6486SRafael Ubal                     ArrayRef<OpFoldResult> inputShape, Value scale,
308*852b6486SRafael Ubal                     Value zeroPoint, QuantizedType quantizedType) {
309*852b6486SRafael Ubal   // Convert scale to tensor if necessary
310*852b6486SRafael Ubal   auto inputType = input.getType();
311*852b6486SRafael Ubal   scale = getScalarOrTensorConstant(
312*852b6486SRafael Ubal       builder, loc, scale, inputType, inputShape);
313*852b6486SRafael Ubal 
314*852b6486SRafael Ubal   // Scale input
315*852b6486SRafael Ubal   auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
316*852b6486SRafael Ubal 
317*852b6486SRafael Ubal   // Skip unnecessary computations if no zero point is given
318*852b6486SRafael Ubal   Value storedValueFloat = scaledValue;
319*852b6486SRafael Ubal   if (!matchPattern(zeroPoint, m_Zero())) {
320*852b6486SRafael Ubal     // Convert zero point to tensor if necessary
321*852b6486SRafael Ubal     zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
322*852b6486SRafael Ubal                                           inputShape);
323*852b6486SRafael Ubal 
324*852b6486SRafael Ubal     // Convert zero point from storage to expressed type
325*852b6486SRafael Ubal     zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
326*852b6486SRafael Ubal                                       scale.getType(),
327*852b6486SRafael Ubal                                       quantizedType.isSigned());
328*852b6486SRafael Ubal 
329*852b6486SRafael Ubal     // Add zero point to stored value
330*852b6486SRafael Ubal     storedValueFloat =
331*852b6486SRafael Ubal         builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
332*852b6486SRafael Ubal   }
333*852b6486SRafael Ubal 
334*852b6486SRafael Ubal   // Convert stored value to storage type
335*852b6486SRafael Ubal   auto storageScalarOrTensorType =
336*852b6486SRafael Ubal       getScalarOrTensorType(quantizedType.getStorageType(), inputType);
337*852b6486SRafael Ubal   auto storedValueInt = convertFloatToInteger(
338*852b6486SRafael Ubal       builder, loc, storedValueFloat, storageScalarOrTensorType,
339*852b6486SRafael Ubal       quantizedType.isSigned());
340*852b6486SRafael Ubal 
341*852b6486SRafael Ubal   // Clamp stored value it if the storage type is bound
342*852b6486SRafael Ubal   auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
343*852b6486SRafael Ubal                                                 inputShape, quantizedType);
344*852b6486SRafael Ubal   return storedValueClamped;
345*852b6486SRafael Ubal }
346*852b6486SRafael Ubal 
347*852b6486SRafael Ubal // Dequantize a scalar or ranked tensor input.
348*852b6486SRafael Ubal //
349*852b6486SRafael Ubal // See function 'convertRanked()' below for a description of the arguments.
350*852b6486SRafael Ubal Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
351*852b6486SRafael Ubal                       ArrayRef<OpFoldResult> inputShape, Value scale,
352*852b6486SRafael Ubal                       Value zeroPoint, QuantizedType quantizedType) {
353*852b6486SRafael Ubal   // Convert scale to tensor if necessary
354*852b6486SRafael Ubal   auto inputType = input.getType();
355*852b6486SRafael Ubal   scale = getScalarOrTensorConstant(
356*852b6486SRafael Ubal       builder, loc, scale, inputType, inputShape);
357*852b6486SRafael Ubal 
358*852b6486SRafael Ubal   // Convert stored value to float
359*852b6486SRafael Ubal   auto result = convertIntegerToFloat(
360*852b6486SRafael Ubal       builder, loc, input, scale.getType(), quantizedType.isSigned());
361*852b6486SRafael Ubal 
362*852b6486SRafael Ubal   // Skip unnecessary computations if no zero point is given
363*852b6486SRafael Ubal   if (!matchPattern(zeroPoint, m_Zero())) {
364*852b6486SRafael Ubal     // Convert zero point to tensor if necessary
365*852b6486SRafael Ubal     zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
366*852b6486SRafael Ubal                                           inputShape);
367*852b6486SRafael Ubal 
368*852b6486SRafael Ubal     // Convert zero point from storage to expressed type
369*852b6486SRafael Ubal     zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
370*852b6486SRafael Ubal                                       scale.getType(),
371*852b6486SRafael Ubal                                       quantizedType.isSigned());
372*852b6486SRafael Ubal 
373*852b6486SRafael Ubal     // Subtract zero point to stored value
374*852b6486SRafael Ubal     result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
375*852b6486SRafael Ubal   }
376*852b6486SRafael Ubal 
377*852b6486SRafael Ubal   // Multiply by scale
378*852b6486SRafael Ubal   result = builder.create<arith::MulFOp>(loc, result, scale);
379*852b6486SRafael Ubal   return result;
380*852b6486SRafael Ubal }
381*852b6486SRafael Ubal 
382*852b6486SRafael Ubal // Convert a scalar or ranked tensor input with the given scale and zero point
383*852b6486SRafael Ubal // values.
384*852b6486SRafael Ubal //
385*852b6486SRafael Ubal // - input
386*852b6486SRafael Ubal //   Scalar or ranked tensor value.
387*852b6486SRafael Ubal //
388*852b6486SRafael Ubal // - inputShape
389*852b6486SRafael Ubal //   If 'input' is a tensor, combination or attributes/values representing its
390*852b6486SRafael Ubal //   static/dynamic dimensions. If 'input' is a scalar, empty list.
391*852b6486SRafael Ubal //
392*852b6486SRafael Ubal // - scale
393*852b6486SRafael Ubal //   Scale as a floating-point scalar value.
394*852b6486SRafael Ubal //
395*852b6486SRafael Ubal // - zeroPoint
396*852b6486SRafael Ubal //   Zero point as an integer scalar value.
397*852b6486SRafael Ubal //
398*852b6486SRafael Ubal // - quantizedType
399*852b6486SRafael Ubal //   Scalar quantized type of the result ('quant.qcast') or of the input
400*852b6486SRafael Ubal //   ('quant.dcast').
401*852b6486SRafael Ubal //
402*852b6486SRafael Ubal Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
403*852b6486SRafael Ubal                     Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
404*852b6486SRafael Ubal                     Value zeroPoint, QuantizedType quantizedType) {
405*852b6486SRafael Ubal   if (isa<QuantizeCastOp>(op))
406*852b6486SRafael Ubal     return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
407*852b6486SRafael Ubal                          quantizedType);
408*852b6486SRafael Ubal   if (isa<DequantizeCastOp>(op))
409*852b6486SRafael Ubal     return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
410*852b6486SRafael Ubal                            quantizedType);
411*852b6486SRafael Ubal   llvm_unreachable("unexpected quant op");
412*852b6486SRafael Ubal }
413*852b6486SRafael Ubal 
414*852b6486SRafael Ubal // Convert an operation using per-layer quantization with a scalar or ranked
415*852b6486SRafael Ubal // tensor input.
416*852b6486SRafael Ubal //
417*852b6486SRafael Ubal // - op
418*852b6486SRafael Ubal //   'quant.dcast' or 'quant.qcast' op.
419*852b6486SRafael Ubal //
420*852b6486SRafael Ubal // - input
421*852b6486SRafael Ubal //   Scalar or ranked tensor.
422*852b6486SRafael Ubal //
423*852b6486SRafael Ubal // - quantizedType
424*852b6486SRafael Ubal //   Per-layer quantized type.
425*852b6486SRafael Ubal //
426*852b6486SRafael Ubal Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
427*852b6486SRafael Ubal                             Value input, UniformQuantizedType quantizedType) {
428*852b6486SRafael Ubal   // Create scale and zero point constants
429*852b6486SRafael Ubal   auto expressedType = quantizedType.getExpressedType();
430*852b6486SRafael Ubal   auto storageType = quantizedType.getStorageType();
431*852b6486SRafael Ubal   auto scaleAttr =
432*852b6486SRafael Ubal       builder.getFloatAttr(expressedType, quantizedType.getScale());
433*852b6486SRafael Ubal   auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
434*852b6486SRafael Ubal   auto zeroPointAttr =
435*852b6486SRafael Ubal       builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
436*852b6486SRafael Ubal   auto zeroPoint =
437*852b6486SRafael Ubal       builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
438*852b6486SRafael Ubal 
439*852b6486SRafael Ubal   auto inputShape = getScalarOrTensorShape(builder, loc, input);
440*852b6486SRafael Ubal   return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
441*852b6486SRafael Ubal                        quantizedType);
442*852b6486SRafael Ubal }
443*852b6486SRafael Ubal 
444*852b6486SRafael Ubal // Convert an operation using per-layer quantization.
445*852b6486SRafael Ubal //
446*852b6486SRafael Ubal // - op
447*852b6486SRafael Ubal //   'quant.dcast' or 'quant.qcast' op.
448*852b6486SRafael Ubal //
449*852b6486SRafael Ubal // - input
450*852b6486SRafael Ubal //   Scalar, ranked tensor, or unranked tensor.
451*852b6486SRafael Ubal //
452*852b6486SRafael Ubal // - quantizedType
453*852b6486SRafael Ubal //   Per-layer quantized type.
454*852b6486SRafael Ubal //
455*852b6486SRafael Ubal Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
456*852b6486SRafael Ubal                       Value input, UniformQuantizedType quantizedType) {
457*852b6486SRafael Ubal   // Flatten input if unranked
458*852b6486SRafael Ubal   bool isUnranked = isa<UnrankedTensorType>(input.getType());
459*852b6486SRafael Ubal   Value inputShape;
460*852b6486SRafael Ubal   if (isUnranked)
461*852b6486SRafael Ubal     std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
462*852b6486SRafael Ubal 
463*852b6486SRafael Ubal   // Process ranked tensor
464*852b6486SRafael Ubal   auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
465*852b6486SRafael Ubal 
466*852b6486SRafael Ubal   // Restore original shape if unranked
467*852b6486SRafael Ubal   if (isUnranked)
468*852b6486SRafael Ubal     result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
469*852b6486SRafael Ubal 
470*852b6486SRafael Ubal   return result;
471*852b6486SRafael Ubal }
472*852b6486SRafael Ubal 
473*852b6486SRafael Ubal // Convert an operation using per-channel quantization and a scalar or ranked
474*852b6486SRafael Ubal // tensor as an input.
475*852b6486SRafael Ubal //
476*852b6486SRafael Ubal // - op
477*852b6486SRafael Ubal //   'quant.dcast' or 'quant.qcast' op.
478*852b6486SRafael Ubal //
479*852b6486SRafael Ubal // - input
480*852b6486SRafael Ubal //   Scalar or ranked tensor.
481*852b6486SRafael Ubal //
482*852b6486SRafael Ubal // - quantizedType
483*852b6486SRafael Ubal //   Per-channel quantized type.
484*852b6486SRafael Ubal //
485*852b6486SRafael Ubal Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
486*852b6486SRafael Ubal                               Value input,
487*852b6486SRafael Ubal                               UniformQuantizedPerAxisType quantizedType,
488*852b6486SRafael Ubal                               int64_t channelAxis) {
489*852b6486SRafael Ubal   auto *context = builder.getContext();
490*852b6486SRafael Ubal 
491*852b6486SRafael Ubal   auto inputType = cast<RankedTensorType>(input.getType());
492*852b6486SRafael Ubal   auto inputRank = inputType.getRank();
493*852b6486SRafael Ubal 
494*852b6486SRafael Ubal   auto scales = materializePerChannelScales(builder, loc, quantizedType);
495*852b6486SRafael Ubal   auto zeroPoints =
496*852b6486SRafael Ubal       materializePerChannelZeroPoints(builder, loc, quantizedType);
497*852b6486SRafael Ubal 
498*852b6486SRafael Ubal   auto elementType = isa<FloatType>(inputType.getElementType())
499*852b6486SRafael Ubal                          ? quantizedType.getStorageType()
500*852b6486SRafael Ubal                          : quantizedType.getExpressedType();
501*852b6486SRafael Ubal   auto initShape = tensor::getMixedSizes(builder, loc, input);
502*852b6486SRafael Ubal   Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
503*852b6486SRafael Ubal 
504*852b6486SRafael Ubal   SmallVector<utils::IteratorType> iteratorTypes(
505*852b6486SRafael Ubal       inputRank, utils::IteratorType::parallel);
506*852b6486SRafael Ubal   auto channelAxisAffineMap = AffineMap::get(
507*852b6486SRafael Ubal       inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
508*852b6486SRafael Ubal   SmallVector<AffineMap> indexingMaps{
509*852b6486SRafael Ubal     builder.getMultiDimIdentityMap(inputRank),
510*852b6486SRafael Ubal     channelAxisAffineMap,
511*852b6486SRafael Ubal     channelAxisAffineMap,
512*852b6486SRafael Ubal     builder.getMultiDimIdentityMap(inputRank)
513*852b6486SRafael Ubal   };
514*852b6486SRafael Ubal   auto result = builder.create<linalg::GenericOp>(
515*852b6486SRafael Ubal       loc,
516*852b6486SRafael Ubal       init.getType(),  // resultType
517*852b6486SRafael Ubal       ValueRange{input, scales, zeroPoints},  // inputs
518*852b6486SRafael Ubal       ValueRange{init},  // outputs
519*852b6486SRafael Ubal       indexingMaps,
520*852b6486SRafael Ubal       iteratorTypes,
521*852b6486SRafael Ubal       [&](OpBuilder& builder, Location loc, ValueRange args) {
522*852b6486SRafael Ubal         assert(args.size() == 4);
523*852b6486SRafael Ubal         auto input = args[0];
524*852b6486SRafael Ubal         auto scale = args[1];
525*852b6486SRafael Ubal         auto zeroPoint = args[2];
526*852b6486SRafael Ubal 
527*852b6486SRafael Ubal         auto result = convertRanked(builder, loc, op, input, {}, scale,
528*852b6486SRafael Ubal                                     zeroPoint, quantizedType);
529*852b6486SRafael Ubal 
530*852b6486SRafael Ubal         builder.create<linalg::YieldOp>(loc, result);
531*852b6486SRafael Ubal       })
532*852b6486SRafael Ubal       .getResult(0);
533*852b6486SRafael Ubal 
534*852b6486SRafael Ubal   return result;
535*852b6486SRafael Ubal }
536*852b6486SRafael Ubal 
537*852b6486SRafael Ubal // Convert an operation using per-channel quantization.
538*852b6486SRafael Ubal //
539*852b6486SRafael Ubal // - op
540*852b6486SRafael Ubal //   'quant.dcast' or 'quant.qcast' op.
541*852b6486SRafael Ubal //
542*852b6486SRafael Ubal // - input
543*852b6486SRafael Ubal //   Scalar, ranked tensor, or unranked tensor.
544*852b6486SRafael Ubal //
545*852b6486SRafael Ubal // - quantizedType
546*852b6486SRafael Ubal //   Per-channel quantized type.
547*852b6486SRafael Ubal //
548*852b6486SRafael Ubal Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
549*852b6486SRafael Ubal                         Value input,
550*852b6486SRafael Ubal                         UniformQuantizedPerAxisType quantizedType) {
551*852b6486SRafael Ubal   // Flatten unranked tensor into a 3D ranked tensor if necessary
552*852b6486SRafael Ubal   bool isUnranked = isa<UnrankedTensorType>(input.getType());
553*852b6486SRafael Ubal   int64_t channelAxis = quantizedType.getQuantizedDimension();
554*852b6486SRafael Ubal   int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
555*852b6486SRafael Ubal   Value inputShape;
556*852b6486SRafael Ubal   if (isUnranked) {
557*852b6486SRafael Ubal     std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
558*852b6486SRafael Ubal         builder, loc, input, channelAxis, channelAxisSize);
559*852b6486SRafael Ubal     channelAxis = 1;
560*852b6486SRafael Ubal   }
561*852b6486SRafael Ubal 
562*852b6486SRafael Ubal   // Work on a ranked tensor
563*852b6486SRafael Ubal   auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
564*852b6486SRafael Ubal                                         channelAxis);
565*852b6486SRafael Ubal 
566*852b6486SRafael Ubal   // Restore original tensor shape if unranked
567*852b6486SRafael Ubal   if (isUnranked)
568*852b6486SRafael Ubal     result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
569*852b6486SRafael Ubal 
570*852b6486SRafael Ubal   return result;
571*852b6486SRafael Ubal }
572*852b6486SRafael Ubal 
573*852b6486SRafael Ubal // Convert a quantization operation.
574*852b6486SRafael Ubal //
575*852b6486SRafael Ubal // - op
576*852b6486SRafael Ubal //   'quant.dcast' or 'quant.qcast' op.
577*852b6486SRafael Ubal //
578*852b6486SRafael Ubal // - input
579*852b6486SRafael Ubal //   Scalar, ranked tensor, or unranked tensor. The element type matches
580*852b6486SRafael Ubal //   the storage type (quant.dcast) or expressed type (quant.qcast) of
581*852b6486SRafael Ubal //   'quantizedType'.
582*852b6486SRafael Ubal //
583*852b6486SRafael Ubal // - quantizedType
584*852b6486SRafael Ubal //   Per-layer or per-channel quantized type.
585*852b6486SRafael Ubal //
586*852b6486SRafael Ubal Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
587*852b6486SRafael Ubal                        Value input, Type quantizedType) {
588*852b6486SRafael Ubal   if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
589*852b6486SRafael Ubal     return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
590*852b6486SRafael Ubal 
591*852b6486SRafael Ubal   if (auto uniformQuantizedPerAxisType =
592*852b6486SRafael Ubal           dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
593*852b6486SRafael Ubal     return convertPerChannel(builder, loc, op, input,
594*852b6486SRafael Ubal                              uniformQuantizedPerAxisType);
595*852b6486SRafael Ubal 
596*852b6486SRafael Ubal   llvm_unreachable("unexpected quantized type");
597*852b6486SRafael Ubal }
598*852b6486SRafael Ubal 
599*852b6486SRafael Ubal // Lowering pattern for 'quant.dcast'
600*852b6486SRafael Ubal struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
601*852b6486SRafael Ubal   using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
602*852b6486SRafael Ubal 
603*852b6486SRafael Ubal   LogicalResult
604*852b6486SRafael Ubal   matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
605*852b6486SRafael Ubal                   ConversionPatternRewriter &rewriter) const override {
606*852b6486SRafael Ubal     auto loc = op.getLoc();
607*852b6486SRafael Ubal     auto input = op.getInput();
608*852b6486SRafael Ubal     auto quantizedType =
609*852b6486SRafael Ubal         cast<QuantizedType>(getScalarType(op.getInput().getType()));
610*852b6486SRafael Ubal 
611*852b6486SRafael Ubal     // Convert quantized input to storage type
612*852b6486SRafael Ubal     auto storageScalarOrTensorType =
613*852b6486SRafael Ubal         getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
614*852b6486SRafael Ubal     input = rewriter.create<quant::StorageCastOp>(
615*852b6486SRafael Ubal         loc, storageScalarOrTensorType, input);
616*852b6486SRafael Ubal 
617*852b6486SRafael Ubal     auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
618*852b6486SRafael Ubal 
619*852b6486SRafael Ubal     rewriter.replaceOp(op, result);
620*852b6486SRafael Ubal     return success();
621*852b6486SRafael Ubal   }
622*852b6486SRafael Ubal };
623*852b6486SRafael Ubal 
624*852b6486SRafael Ubal // Lowering pattern for 'quant.qcast'
625*852b6486SRafael Ubal struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
626*852b6486SRafael Ubal   using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
627*852b6486SRafael Ubal 
628*852b6486SRafael Ubal   LogicalResult
629*852b6486SRafael Ubal   matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
630*852b6486SRafael Ubal                   ConversionPatternRewriter &rewriter) const override {
631*852b6486SRafael Ubal     auto loc = op.getLoc();
632*852b6486SRafael Ubal     auto input = op.getInput();
633*852b6486SRafael Ubal     auto quantizedType = getScalarType(op.getResult().getType());
634*852b6486SRafael Ubal 
635*852b6486SRafael Ubal     // Flatten unranked tensor input
636*852b6486SRafael Ubal     auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
637*852b6486SRafael Ubal 
638*852b6486SRafael Ubal     // Cast stored value to result quantized value
639*852b6486SRafael Ubal     rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
640*852b6486SRafael Ubal         op, op.getResult().getType(), result);
641*852b6486SRafael Ubal     return success();
642*852b6486SRafael Ubal   }
643*852b6486SRafael Ubal };
644*852b6486SRafael Ubal 
645*852b6486SRafael Ubal struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
646*852b6486SRafael Ubal   void runOnOperation() override {
647*852b6486SRafael Ubal     RewritePatternSet patterns(&getContext());
648*852b6486SRafael Ubal     populateLowerQuantOpsPatterns(patterns);
649*852b6486SRafael Ubal 
650*852b6486SRafael Ubal     ConversionTarget target(getContext());
651*852b6486SRafael Ubal     target.addLegalOp<quant::StorageCastOp>();
652*852b6486SRafael Ubal     target.addIllegalDialect<quant::QuantDialect>();
653*852b6486SRafael Ubal     target.addLegalDialect<
654*852b6486SRafael Ubal       arith::ArithDialect,
655*852b6486SRafael Ubal       linalg::LinalgDialect,
656*852b6486SRafael Ubal       shape::ShapeDialect,
657*852b6486SRafael Ubal       tensor::TensorDialect
658*852b6486SRafael Ubal     >();
659*852b6486SRafael Ubal 
660*852b6486SRafael Ubal     if (failed(applyPartialConversion(getOperation(), target,
661*852b6486SRafael Ubal                                       std::move(patterns))))
662*852b6486SRafael Ubal       signalPassFailure();
663*852b6486SRafael Ubal   }
664*852b6486SRafael Ubal };
665*852b6486SRafael Ubal 
666*852b6486SRafael Ubal } // namespace
667*852b6486SRafael Ubal 
668*852b6486SRafael Ubal void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
669*852b6486SRafael Ubal   patterns.add<
670*852b6486SRafael Ubal     DequantizeCastOpConversion,
671*852b6486SRafael Ubal     QuantizeCastOpConversion
672*852b6486SRafael Ubal   >(patterns.getContext());
673*852b6486SRafael Ubal }
674*852b6486SRafael Ubal 
675*852b6486SRafael Ubal } // namespace quant
676*852b6486SRafael Ubal } // namespace mlir
677