xref: /llvm-project/mlir/lib/Dialect/Quant/IR/QuantOps.cpp (revision 852b6486246141e44cc9f126f542a2ae0d73b3d6)
1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "QuantDialectBytecode.h"
10 #include "TypeDetail.h"
11 
12 #include "mlir/Dialect/Quant/IR/Quant.h"
13 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
19 
20 
21 namespace mlir {
22 namespace quant {
23 
24 namespace {
25 
26 // Verify the integrity of per-axis quantization information, if present.
27 //
28 // - quantizedType
29 //   Any quantized type. Any quantized type with no per-axis quantization is
30 //   ignored.
31 //
32 // - containerType
33 //   Original input or result type of the operation using the provided quantized
34 //   type. Used to ensure that the quantized type appears within a tensor and
35 //   that the tensor is compatible with per-axis quantization information.
36 //
37 LogicalResult verifyPerAxisQuantization(Operation *op,
38                                         QuantizedType quantizedType,
39                                         Type containerType) {
40   auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
41   if (!quantizedPerAxisType)
42     return success();
43 
44   auto tensorType = dyn_cast<TensorType>(containerType);
45   if (!tensorType)
46     return op->emitError("scalar types may not use per-axis quantization");
47 
48   if (!tensorType.hasRank())
49     return success();
50 
51   int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
52   if (quantizedDimension >= tensorType.getRank())
53     return op->emitError("quantized dimension must be less than tensor rank");
54 
55   int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
56   if (quantizedDimensionSize != ShapedType::kDynamic &&
57       quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
58     return op->emitError(
59         "quantized dimension size does not match number of scales");
60 
61   return success();
62 }
63 
64 // Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
65 //
66 // - quantizedType
67 //   Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
68 //   whether as a primitive type or in a tensor.
69 //
70 // - floatType
71 //   Float type used in the input ('quant.qcast') or result ('quant.dcast'),
72 //   whether as a primitive type or in a tensor.
73 //
74 // - containerType
75 //   Type of original input or result.
76 //
77 LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
78                                    FloatType floatType, Type containerType) {
79   if (quantizedType.getExpressedType() != floatType)
80     return op->emitError(
81         "expressed type in quantized type expected to match float type");
82 
83   // Veriy integrity of per-axis quantization information, if present.
84   return verifyPerAxisQuantization(op, quantizedType, containerType);
85 }
86 
87 }  // namespace
88 
89 
90 //===----------------------------------------------------------------------===//
91 // Dialect
92 //===----------------------------------------------------------------------===//
93 
// Registers all types and operations of the 'quant' dialect and attaches the
// dialect's bytecode (de)serialization interface.
void QuantDialect::initialize() {
  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType>();
  // The op list is generated by TableGen into QuantOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
      >();
  // Enable custom bytecode encoding for the dialect's types.
  detail::addBytecodeInterface(this);
}
103 
104 
105 //===----------------------------------------------------------------------===//
106 // DequantizeCastOp
107 //===----------------------------------------------------------------------===//
108 
109 LogicalResult DequantizeCastOp::verify() {
110   return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
111                               getInput().getType());
112 }
113 
114 OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
115   // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
116   // with the value of x. Values x and y are guaranteed to be of the same type
117   // in this pattern.
118   auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
119   if (!srcQcastOp)
120     return {};
121   assert(srcQcastOp.getInput().getType() == getType());
122   return srcQcastOp.getInput();
123 }
124 
125 FloatType DequantizeCastOp::getFloatType() {
126   return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
127 }
128 
129 QuantizedType DequantizeCastOp::getQuantizedType() {
130   return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
131 }
132 
133 
134 //===----------------------------------------------------------------------===//
135 // QuantizeCastOp
136 //===----------------------------------------------------------------------===//
137 
138 LogicalResult QuantizeCastOp::verify() {
139   return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
140                               getInput().getType());
141 }
142 
143 OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
144   // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
145   // with the value of x if the casts invert each other. Contrary to the folding
146   // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
147   // x and y are not guaranteed to be of the same type here, as they may use
148   // different quantization parameters.
149   auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
150   if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
151     return {};
152   return srcDcastOp.getInput();
153 }
154 
155 FloatType QuantizeCastOp::getFloatType() {
156   return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
157 }
158 
159 QuantizedType QuantizeCastOp::getQuantizedType() {
160   return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
161 }
162 
163 
164 //===----------------------------------------------------------------------===//
165 // StorageCastOp
166 //===----------------------------------------------------------------------===//
167 
168 LogicalResult StorageCastOp::verify() {
169   auto quantizedType = getQuantizedType();
170   auto integerType = getIntegerType();
171   if (quantizedType.getStorageType() != integerType)
172     return emitError(
173         "storage type in quantized type expected to match integer type");
174 
175   // Verify integrity of per-axis quantization information, if available. While
176   // the quantization type may appear in the input or the result, their tensor
177   // shapes are guaranteed to be identical at this point.
178   return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
179 }
180 
181 OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
182   // Matches x -> quant.scast -> quant.scast -> y, replacing the second
183   // quant.scast with the value of x if the casts invert each other.
184   auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
185   if (!srcScastOp || srcScastOp.getInput().getType() != getType())
186     return {};
187   return srcScastOp.getInput();
188 }
189 
190 IntegerType StorageCastOp::getIntegerType() {
191   auto inputScalarType = getElementTypeOrSelf(getInput().getType());
192   if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
193     return integerType;
194 
195   auto resultScalarType = getElementTypeOrSelf(getResult().getType());
196   return cast<IntegerType>(resultScalarType);
197 }
198 
199 QuantizedType StorageCastOp::getQuantizedType() {
200   auto inputScalarType = getElementTypeOrSelf(getInput().getType());
201   if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
202     return quantizedType;
203 
204   auto resultScalarType = getElementTypeOrSelf(getResult().getType());
205   return cast<QuantizedType>(resultScalarType);
206 }
207 
208 
209 } // namespace quant
210 } // namespace mlir
211 
212 #define GET_OP_CLASSES
213 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
214 
215