//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace quant {

#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"

namespace {

// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
Type getScalarType(Type inputType) {
  if (auto tensorType = dyn_cast<TensorType>(inputType))
    return tensorType.getElementType();
  return inputType;
}

// Return the shape of an input value as a list of attributes (static
// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an
// empty list is returned. If 'input' is a tensor, its shape is returned.
SmallVector<OpFoldResult>
getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
  if (isa<TensorType>(input.getType()))
    return tensor::getMixedSizes(builder, loc, input);
  return {};
}

// If 'referenceType' is a scalar, return 'elementType' as is. If
// 'referenceType' is a tensor, return another tensor with the same shape and
// elements of type 'elementType'.
Type getScalarOrTensorType(Type elementType, Type referenceType) {
  if (auto tensorType = dyn_cast<TensorType>(referenceType))
    return tensorType.clone(elementType);
  return elementType;
}

// Return a constant with the given value. If 'referenceType' is a tensor, a
// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
// scalar, 'referenceShape' is ignored and a scalar constant is returned.
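//
// For illustration, a hedged sketch (value names are made up): given a scalar
// constant %c and reference type 'tensor<?x4xf32>' with mixed shape {%d0, 4},
// this emits roughly:
//
//   %splat = tensor.splat %c[%d0] : tensor<?x4xf32>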
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) {
  // If the result type is a scalar, return the unmodified scalar constant.
  auto tensorType = dyn_cast<TensorType>(referenceType);
  if (!tensorType) {
    assert(referenceShape.empty());
    return scalar;
  }

  // Create tensor splat
  auto tensorConstant =
      builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
  return tensorConstant;
}

// Reshape an unranked tensor into a 1D ranked tensor.
//
// - input
//   Unranked tensor.
//
// Return values:
//
// - flatInput
//   1D ranked, dynamically shaped tensor.
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
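// For illustration, a hedged sketch of the emitted IR for an 'f32' input
// (value names are made up):
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   %size = shape.num_elements %shape : tensor<?xindex> -> index
//   %flat_shape = tensor.from_elements %size : tensor<1xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//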
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) {
  // Get unranked input shape and total size
  auto *context = builder.getContext();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
  Value inputSize = builder.create<shape::NumElementsOp>(
      loc, builder.getIndexType(), inputShape);

  // Turn input size into 1D tensor
  auto flatShapeType = shape::getExtentTensorType(context, 1);
  auto flatInputShape = builder.create<tensor::FromElementsOp>(
      loc, flatShapeType, inputSize);

  // Reshape input tensor into 1D
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType =
      RankedTensorType::get({ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(
      loc, flatInputType, input, flatInputShape);
  return std::make_pair(flatInput, inputShape);
}

// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
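// For illustration, a hedged, schematic sketch for axis = 1 and axisSize = 4
// (value names are made up; the 'shape.split_at' steps are abbreviated):
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   // split %shape at indices 1 and 2, then reduce the outer parts with
//   // shape.num_elements into %size_left and %size_right
//   %flat_shape = tensor.from_elements %size_left, %c4, %size_right
//       : tensor<3xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x4x?xf32>
//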
std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
                                                        Location loc,
                                                        Value input,
                                                        int64_t axis,
                                                        int64_t axisSize) {
  // Get full tensor shape
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);

  // Get shape and sizes on left and right of axis
  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
  auto shapeLeft = builder.create<shape::SplitAtOp>(
      loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
      .getResult(0);
  auto sizeLeft = builder.create<shape::NumElementsOp>(
      loc, indexType, shapeLeft);
  auto shapeRight = builder.create<shape::SplitAtOp>(
      loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
      .getResult(1);
  auto sizeRight = builder.create<shape::NumElementsOp>(
      loc, indexType, shapeRight);

  // Compute flat input shape as a 3-element 1D tensor
  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = builder.create<tensor::FromElementsOp>(
      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape input to 3D tensor
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(
      loc, flatInputType, input, flatInputShape);

  return std::make_pair(flatInput, inputShape);
}

// Reshape an input tensor into its original unranked shape.
//
// - input
//   Ranked tensor.
//
// - inputShape
//   1D extent tensor.
//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto unrankedType = UnrankedTensorType::get(elementType);
  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
                                           inputShape);
}

// Create a tensor constant containing all scales in a per-channel quantized
// type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
    return builder.getFloatAttr(expressedType, scale);
  });
  auto tensorType =
      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}

216 
217 // Create a tensor constant containing all zero points in a per-channel
218 // quantized type. Example:
219 //
220 //   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
221 //
222 // produces
223 //
224 //   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
225 //
226 Value materializePerChannelZeroPoints(
227     OpBuilder &builder, Location loc,
228     UniformQuantizedPerAxisType quantizedType) {
229   auto zeroPoints = quantizedType.getZeroPoints();
230   auto storageType = quantizedType.getStorageType();
231   auto zeroPointAttrs = llvm::map_to_vector(
232       zeroPoints,
233       [&](int64_t zeroPoint) -> Attribute {
234         return builder.getIntegerAttr(storageType, zeroPoint);
235       });
236   auto tensorType =
237       RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
238   auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
239   return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
240 }
241 
// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
// - input
//   Scalar or ranked tensor input. The element type must match the storage
//   type of 'quantizedType'.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Per-layer or per-channel quantized type.
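//
// For illustration, a hedged sketch for a signed scalar input whose type
// narrows the storage range, e.g. '!quant.uniform<i8<-8:7>:f32, 1.0>'
// (value names are made up):
//
//   %min = arith.constant -8 : i8
//   %max = arith.constant 7 : i8
//   %0 = arith.maxsi %input, %min : i8
//   %1 = arith.minsi %0, %max : i8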
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // If quantized type does not narrow down the storage type range, there is
  // nothing to do.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  // Materialize bounds
  auto inputType = input.getType();
  auto storageType = quantizedType.getStorageType();
  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMin(), storageType);
  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMax(), storageType);
  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                              inputType, inputShape);
  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
                                              inputType, inputShape);

  // Clamp
  if (quantizedType.isSigned()) {
    input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
    input = builder.create<arith::MinSIOp>(loc, input, storageMax);
  } else {
    input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
    input = builder.create<arith::MinUIOp>(loc, input, storageMax);
  }
  return input;
}

// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return builder.create<arith::FPToSIOp>(loc, resultType, input);
  return builder.create<arith::FPToUIOp>(loc, resultType, input);
}

// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return builder.create<arith::SIToFPOp>(loc, resultType, input);
  return builder.create<arith::UIToFPOp>(loc, resultType, input);
}

// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
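//
// For illustration, a hedged sketch for a scalar f32-to-i8 quantization with
// a non-zero zero point (value names are made up; clamping omitted):
//
//   %scaled = arith.divf %input, %scale : f32
//   %zp_f = arith.sitofp %zero_point : i8 to f32
//   %shifted = arith.addf %scaled, %zp_f : f32
//   %stored = arith.fptosi %shifted : f32 to i8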
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(
      builder, loc, scale, inputType, inputShape);

  // Scale input
  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);

  // Skip unnecessary computations if no zero point is given
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
                                      scale.getType(),
                                      quantizedType.isSigned());

    // Add zero point to stored value
    storedValueFloat =
        builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
  }

  // Convert stored value to storage type
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt = convertFloatToInteger(
      builder, loc, storedValueFloat, storageScalarOrTensorType,
      quantizedType.isSigned());

  // Clamp stored value if the storage type range is bounded
  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
                                                inputShape, quantizedType);
  return storedValueClamped;
}

// Dequantize a scalar or ranked tensor input.
//
// See function 'convertRanked()' below for a description of the arguments.
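//
// For illustration, a hedged sketch for a scalar i8-to-f32 dequantization
// with a non-zero zero point (value names are made up):
//
//   %in_f = arith.sitofp %input : i8 to f32
//   %zp_f = arith.sitofp %zero_point : i8 to f32
//   %shifted = arith.subf %in_f, %zp_f : f32
//   %result = arith.mulf %shifted, %scale : f32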
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(
      builder, loc, scale, inputType, inputShape);

  // Convert stored value to float
  auto result = convertIntegerToFloat(
      builder, loc, input, scale.getType(), quantizedType.isSigned());

  // Skip unnecessary computations if no zero point is given
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
                                      scale.getType(),
                                      quantizedType.isSigned());

    // Subtract zero point from stored value
    result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
  }

  // Multiply by scale
  result = builder.create<arith::MulFOp>(loc, result, scale);
  return result;
}

// Convert a scalar or ranked tensor input with the given scale and zero point
// values.
//
// - input
//   Scalar or ranked tensor value.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
//   Scale as a floating-point scalar value.
//
// - zeroPoint
//   Zero point as an integer scalar value.
//
// - quantizedType
//   Scalar quantized type of the result ('quant.qcast') or of the input
//   ('quant.dcast').
//
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  if (isa<QuantizeCastOp>(op))
    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                         quantizedType);
  if (isa<DequantizeCastOp>(op))
    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                           quantizedType);
  llvm_unreachable("unexpected quant op");
}

// Convert an operation using per-layer quantization with a scalar or ranked
// tensor input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input, UniformQuantizedType quantizedType) {
  // Create scale and zero point constants
  auto expressedType = quantizedType.getExpressedType();
  auto storageType = quantizedType.getStorageType();
  auto scaleAttr =
      builder.getFloatAttr(expressedType, quantizedType.getScale());
  auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
  auto zeroPointAttr =
      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
  auto zeroPoint =
      builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);

  auto inputShape = getScalarOrTensorShape(builder, loc, input);
  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
                       quantizedType);
}

// Convert an operation using per-layer quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) {
  // Flatten input if unranked
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  Value inputShape;
  if (isUnranked)
    std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);

  // Process ranked tensor
  auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);

  // Restore original shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert an operation using per-channel quantization with a ranked tensor
// input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
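// For illustration, a hedged sketch of the emitted IR for a 2D input
// quantized along axis 1 with two channels (value and map names are made up;
// the scalar body is abbreviated):
//
//   #id = affine_map<(d0, d1) -> (d0, d1)>
//   #ch = affine_map<(d0, d1) -> (d1)>
//   %init = tensor.empty(%d0) : tensor<?x2xi8>
//   %res = linalg.generic
//       {indexing_maps = [#id, #ch, #ch, #id],
//        iterator_types = ["parallel", "parallel"]}
//       ins(%input, %scales, %zps
//           : tensor<?x2xf32>, tensor<2xf32>, tensor<2xi8>)
//       outs(%init : tensor<?x2xi8>) {
//   ^bb0(%in: f32, %scale: f32, %zp: i8, %out: i8):
//     // scalar logic from 'convertRanked' (quantize or dequantize)
//     linalg.yield %stored : i8
//   } -> tensor<?x2xi8>
//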
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(
      inputRank, utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
    builder.getMultiDimIdentityMap(inputRank),
    channelAxisAffineMap,
    channelAxisAffineMap,
    builder.getMultiDimIdentityMap(inputRank)
  };
  auto result = builder.create<linalg::GenericOp>(
      loc,
      init.getType(),  // resultType
      ValueRange{input, scales, zeroPoints},  // inputs
      ValueRange{init},  // outputs
      indexingMaps,
      iteratorTypes,
      [&](OpBuilder &builder, Location loc, ValueRange args) {
        assert(args.size() == 4);
        auto input = args[0];
        auto scale = args[1];
        auto zeroPoint = args[2];

        auto result = convertRanked(builder, loc, op, input, {}, scale,
                                    zeroPoint, quantizedType);

        builder.create<linalg::YieldOp>(loc, result);
      })
      .getResult(0);

  return result;
}

// Convert an operation using per-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked or unranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedPerAxisType quantizedType) {
  // Flatten unranked tensor into a 3D ranked tensor if necessary
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  int64_t channelAxis = quantizedType.getQuantizedDimension();
  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
  Value inputShape;
  if (isUnranked) {
    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
        builder, loc, input, channelAxis, channelAxisSize);
    channelAxis = 1;
  }

  // Work on a ranked tensor
  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
                                        channelAxis);

  // Restore original tensor shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert a quantization operation.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor. The element type matches
//   the storage type (quant.dcast) or expressed type (quant.qcast) of
//   'quantizedType'.
//
// - quantizedType
//   Per-layer or per-channel quantized type.
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  if (auto uniformQuantizedType =
          dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);

  if (auto uniformQuantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input,
                             uniformQuantizedPerAxisType);

  llvm_unreachable("unexpected quantized type");
}

// Lowering pattern for 'quant.dcast'
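//
// For illustration, a hedged end-to-end sketch (constant names are made up):
// a per-layer scalar op such as
//
//   %1 = quant.dcast %0 : !quant.uniform<i8:f32, 2.0:10> to f32
//
// lowers roughly to:
//
//   %stored = quant.scast %0 : !quant.uniform<i8:f32, 2.0:10> to i8
//   %in_f = arith.sitofp %stored : i8 to f32
//   %zp_f = arith.sitofp %zp : i8 to f32
//   %shifted = arith.subf %in_f, %zp_f : f32
//   %1 = arith.mulf %shifted, %scale : f32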
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Convert quantized input to storage type
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = rewriter.create<quant::StorageCastOp>(
        loc, storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};

// Lowering pattern for 'quant.qcast'
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType = getScalarType(op.getResult().getType());

    // Convert the expressed-type input to its stored representation
    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Cast stored value to result quantized value
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};

struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateLowerQuantOpsPatterns(patterns);

    ConversionTarget target(getContext());
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<
      arith::ArithDialect,
      linalg::LinalgDialect,
      shape::ShapeDialect,
      tensor::TensorDialect
    >();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<
    DequantizeCastOpConversion,
    QuantizeCastOpConversion
  >(patterns.getContext());
}

} // namespace quant
} // namespace mlir