1 //===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Transforms `quant.dcast` and `quant.qcast` into lower-level ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Quant/IR/Quant.h" 17 #include "mlir/Dialect/Quant/IR/QuantTypes.h" 18 #include "mlir/Dialect/Quant/Transforms/Passes.h" 19 #include "mlir/Dialect/Shape/IR/Shape.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 namespace mlir { 26 namespace quant { 27 28 #define GEN_PASS_DEF_LOWERQUANTOPS 29 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc" 30 31 namespace { 32 33 // If 'inputType' is a tensor, return its element type. If it is a scalar, 34 // return it as is. 35 Type getScalarType(Type inputType) { 36 if (auto tensorType = dyn_cast<TensorType>(inputType)) 37 return tensorType.getElementType(); 38 return inputType; 39 } 40 41 // Return the shape of an input value as a list of attributes (static dimensions) 42 // and values (dynamic dimensions). If 'input' is a scalar, an empty list is 43 // returned. If 'input' is a tensor, its shape is returned. 44 SmallVector<OpFoldResult> 45 getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) { 46 if (isa<TensorType>(input.getType())) 47 return tensor::getMixedSizes(builder, loc, input); 48 return {}; 49 } 50 51 // If 'referenceType' is a scalar, return 'elementType' as is. 
If 52 // 'referenceType' is a tensor, return another tensor with the same shape and 53 // elements of type 'elementType'. 54 Type getScalarOrTensorType(Type elementType, Type referenceType) { 55 if (auto tensorType = dyn_cast<TensorType>(referenceType)) 56 return tensorType.clone(elementType); 57 return elementType; 58 } 59 60 // Return a constant with the given value. If 'referenceType' is a tensor, a 61 // tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a 62 // scalar, 'referenceShape' is ignored and a scalar constant is returned. 63 Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar, 64 Type referenceType, 65 ArrayRef<OpFoldResult> referenceShape) { 66 // If the result type is a scalar, return the unmodified scalar constant. 67 auto tensorType = dyn_cast<TensorType>(referenceType); 68 if (!tensorType) { 69 assert(referenceShape.empty()); 70 return scalar; 71 } 72 73 // Create tensor splat 74 auto tensorConstant = 75 builder.create<tensor::SplatOp>(loc, scalar, referenceShape); 76 return tensorConstant; 77 } 78 79 // Reshape an unranked tensor into a 1D ranked tensor. 80 // 81 // - input 82 // Unranked tensor. 83 // 84 // Return values: 85 // 86 // - flatInput 87 // 1D ranked, dynamically shaped tensor. 88 // 89 // - inputShape 90 // 1D extent tensor containing the shape of the original unranked input. 
91 // 92 std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc, 93 Value input) { 94 // Get unranked input shape and total size 95 auto *context = builder.getContext(); 96 auto shapeType = shape::getExtentTensorType(context); 97 auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input); 98 Value inputSize = builder.create<shape::NumElementsOp>( 99 loc, builder.getIndexType(), inputShape); 100 101 // Turn input size into 1D tensor 102 auto flatShapeType = shape::getExtentTensorType(context, 1); 103 auto flatInputShape = builder.create<tensor::FromElementsOp>( 104 loc, flatShapeType, inputSize); 105 106 // Reshape input tensor into 1D 107 auto inputType = cast<UnrankedTensorType>(input.getType()); 108 auto elementType = inputType.getElementType(); 109 auto flatInputType = 110 RankedTensorType::get({ShapedType::kDynamic}, elementType); 111 auto flatInput = builder.create<tensor::ReshapeOp>( 112 loc, flatInputType, input, flatInputShape); 113 return std::make_pair(flatInput, inputShape); 114 } 115 116 // Reshape an unranked tensor into a 3D ranked tensor where the central 117 // dimension of the result tensor corresponds to dimension 'axis' of the input 118 // tensor. 119 // 120 // - input 121 // Unranked tensor. 122 // 123 // - axis 124 // Index of the input dimension around which other input dimiensions will be 125 // collapsed. 126 // 127 // - axisSize 128 // Size of input dimension 'axis'. 129 // 130 // Return values: 131 // 132 // - flatInput 133 // 3D ranked tensor of shape [?, axisSize, ?]. 134 // 135 // - inputShape 136 // 1D extent tensor containing the shape of the original unranked input. 
//
std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
                                                        Location loc,
                                                        Value input,
                                                        int64_t axis,
                                                        int64_t axisSize) {
  // Get full tensor shape as a 1D extent tensor
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);

  // Get shape and sizes on left and right of axis. 'shape.split_at' at index
  // 'axis' yields the extents before the axis (result 0); splitting at
  // 'axis + 1' yields the extents after the axis (result 1).
  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
  auto shapeLeft =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisValue)
          .getResult(0);
  auto sizeLeft =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
  auto shapeRight =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisNextValue)
          .getResult(1);
  auto sizeRight =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);

  // Compute flat input shape as a 3-element 1D tensor:
  // [product(left dims), axisSize, product(right dims)]
  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = builder.create<tensor::FromElementsOp>(
      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape input to a 3D tensor of shape [?, axisSize, ?]
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
                                                     flatInputShape);

  return std::make_pair(flatInput, inputShape);
}

// Reshape an input tensor into its original unranked shape.
//
// - input
//   Ranked tensor.
184 // 185 // - inputShape 186 // 1D extent tensor. 187 // 188 Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, 189 Value inputShape) { 190 auto inputType = cast<RankedTensorType>(input.getType()); 191 auto elementType = inputType.getElementType(); 192 auto unrankedType = UnrankedTensorType::get(elementType); 193 return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape); 194 } 195 196 // Create a tensor constant containing all scales in a per-channel quantized 197 // type. Example: 198 // 199 // !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}> 200 // 201 // produces 202 // 203 // %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32> 204 // 205 Value materializePerChannelScales(OpBuilder &builder, Location loc, 206 UniformQuantizedPerAxisType quantizedType) { 207 auto scales = quantizedType.getScales(); 208 auto expressedType = quantizedType.getExpressedType(); 209 auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute { 210 return builder.getFloatAttr(expressedType, scale); 211 }); 212 auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType); 213 auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); 214 return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr); 215 } 216 217 // Create a tensor constant containing all zero points in a per-channel 218 // quantized type. 
Example: 219 // 220 // !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}> 221 // 222 // produces 223 // 224 // %cst = arith.constant dense<[10, 20]> : tensor<2xi8> 225 // 226 Value materializePerChannelZeroPoints( 227 OpBuilder &builder, Location loc, 228 UniformQuantizedPerAxisType quantizedType) { 229 auto zeroPoints = quantizedType.getZeroPoints(); 230 auto storageType = quantizedType.getStorageType(); 231 auto zeroPointAttrs = llvm::map_to_vector( 232 zeroPoints, 233 [&](int64_t zeroPoint) -> Attribute { 234 return builder.getIntegerAttr(storageType, zeroPoint); 235 }); 236 auto tensorType = 237 RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType); 238 auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); 239 return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr); 240 } 241 242 // Clamp the given scalar or tensor input using the storage bounds encoded in 243 // the given quantized type, if present. 244 // 245 // - input 246 // Scalar or ranked tensor input. The element type must match the storage type 247 // of 'quantizedType'. 248 // 249 // - inputShape 250 // If 'input' is a tensor, combination of attributes/values representing its 251 // static/dynamic dimensions. If 'input' is a scalar, empty list. 252 // 253 // - quantizedType 254 // Per-axis or per-channel quantized type. 255 Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, 256 ArrayRef<OpFoldResult> inputShape, 257 QuantizedType quantizedType) { 258 // If quantized type does not narrow down the storage type range, there is 259 // nothing to do. 
260 if (!quantizedType.hasStorageTypeBounds()) 261 return input; 262 263 // Materialize bounds 264 auto inputType = input.getType(); 265 auto storageType = quantizedType.getStorageType(); 266 auto storageMinScalar = builder.create<arith::ConstantIntOp>( 267 loc, quantizedType.getStorageTypeMin(), storageType); 268 auto storageMaxScalar = builder.create<arith::ConstantIntOp>( 269 loc, quantizedType.getStorageTypeMax(), storageType); 270 auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar, 271 inputType, inputShape); 272 auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar, 273 inputType, inputShape); 274 275 // Clamp 276 if (quantizedType.isSigned()) { 277 input = builder.create<arith::MaxSIOp>(loc, input, storageMin); 278 input = builder.create<arith::MinSIOp>(loc, input, storageMax); 279 } else { 280 input = builder.create<arith::MaxUIOp>(loc, input, storageMin); 281 input = builder.create<arith::MinUIOp>(loc, input, storageMax); 282 } 283 return input; 284 } 285 286 // Emit op 'arith.fptosi' or 'arith.fptoui'. 287 Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, 288 Type resultType, bool isSigned) { 289 if (isSigned) 290 return builder.create<arith::FPToSIOp>(loc, resultType, input); 291 return builder.create<arith::FPToUIOp>(loc, resultType, input); 292 } 293 294 // Emit op 'arith.sitofp' or 'arith.uitofp'. 295 Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, 296 Type resultType, bool isSigned) { 297 if (isSigned) 298 return builder.create<arith::SIToFPOp>(loc, resultType, input); 299 return builder.create<arith::UIToFPOp>(loc, resultType, input); 300 } 301 302 // Quantize a scalar or ranked tensor value. The stored value is clamped using 303 // the storage bounds encoded in the given quantized type. 304 // 305 // See function 'convertRanked()' below for a description of the arguments. 
306 Value quantizeValue(OpBuilder &builder, Location loc, Value input, 307 ArrayRef<OpFoldResult> inputShape, Value scale, 308 Value zeroPoint, QuantizedType quantizedType) { 309 // Convert scale to tensor if necessary 310 auto inputType = input.getType(); 311 scale = getScalarOrTensorConstant( 312 builder, loc, scale, inputType, inputShape); 313 314 // Scale input 315 auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale); 316 317 // Skip unnecessary computations if no zero point is given 318 Value storedValueFloat = scaledValue; 319 if (!matchPattern(zeroPoint, m_Zero())) { 320 // Convert zero point to tensor if necessary 321 zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, 322 inputShape); 323 324 // Convert zero point from storage to expressed type 325 zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, 326 scale.getType(), 327 quantizedType.isSigned()); 328 329 // Add zero point to stored value 330 storedValueFloat = 331 builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint); 332 } 333 334 // Convert stored value to storage type 335 auto storageScalarOrTensorType = 336 getScalarOrTensorType(quantizedType.getStorageType(), inputType); 337 auto storedValueInt = convertFloatToInteger( 338 builder, loc, storedValueFloat, storageScalarOrTensorType, 339 quantizedType.isSigned()); 340 341 // Clamp stored value it if the storage type is bound 342 auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt, 343 inputShape, quantizedType); 344 return storedValueClamped; 345 } 346 347 // Dequantize a scalar or ranked tensor input. 348 // 349 // See function 'convertRanked()' below for a description of the arguments. 
350 Value dequantizeValue(OpBuilder &builder, Location loc, Value input, 351 ArrayRef<OpFoldResult> inputShape, Value scale, 352 Value zeroPoint, QuantizedType quantizedType) { 353 // Convert scale to tensor if necessary 354 auto inputType = input.getType(); 355 scale = getScalarOrTensorConstant( 356 builder, loc, scale, inputType, inputShape); 357 358 // Convert stored value to float 359 auto result = convertIntegerToFloat( 360 builder, loc, input, scale.getType(), quantizedType.isSigned()); 361 362 // Skip unnecessary computations if no zero point is given 363 if (!matchPattern(zeroPoint, m_Zero())) { 364 // Convert zero point to tensor if necessary 365 zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, 366 inputShape); 367 368 // Convert zero point from storage to expressed type 369 zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, 370 scale.getType(), 371 quantizedType.isSigned()); 372 373 // Subtract zero point to stored value 374 result = builder.create<arith::SubFOp>(loc, result, zeroPoint); 375 } 376 377 // Multiply by scale 378 result = builder.create<arith::MulFOp>(loc, result, scale); 379 return result; 380 } 381 382 // Convert a scalar or ranked tensor input with the given scale and zero point 383 // values. 384 // 385 // - input 386 // Scalar or ranked tensor value. 387 // 388 // - inputShape 389 // If 'input' is a tensor, combination or attributes/values representing its 390 // static/dynamic dimensions. If 'input' is a scalar, empty list. 391 // 392 // - scale 393 // Scale as a floating-point scalar value. 394 // 395 // - zeroPoint 396 // Zero point as an integer scalar value. 397 // 398 // - quantizedType 399 // Scalar quantized type of the result ('quant.qcast') or of the input 400 // ('quant.dcast'). 
401 // 402 Value convertRanked(OpBuilder &builder, Location loc, Operation *op, 403 Value input, ArrayRef<OpFoldResult> inputShape, Value scale, 404 Value zeroPoint, QuantizedType quantizedType) { 405 if (isa<QuantizeCastOp>(op)) 406 return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint, 407 quantizedType); 408 if (isa<DequantizeCastOp>(op)) 409 return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint, 410 quantizedType); 411 llvm_unreachable("unexpected quant op"); 412 } 413 414 // Convert an operation using per-layer quantization with a scalar or ranked 415 // tensor input. 416 // 417 // - op 418 // 'quant.dcast' or 'quant.qcast' op. 419 // 420 // - input 421 // Scalar or ranked tensor. 422 // 423 // - quantizedType 424 // Per-layer quantized type. 425 // 426 Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, 427 Value input, UniformQuantizedType quantizedType) { 428 // Create scale and zero point constants 429 auto expressedType = quantizedType.getExpressedType(); 430 auto storageType = quantizedType.getStorageType(); 431 auto scaleAttr = 432 builder.getFloatAttr(expressedType, quantizedType.getScale()); 433 auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr); 434 auto zeroPointAttr = 435 builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); 436 auto zeroPoint = 437 builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr); 438 439 auto inputShape = getScalarOrTensorShape(builder, loc, input); 440 return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint, 441 quantizedType); 442 } 443 444 // Convert an operation using per-layer quantization. 445 // 446 // - op 447 // 'quant.dcast' or 'quant.qcast' op. 448 // 449 // - input 450 // Scalar, ranked tensor, or unranked tensor. 451 // 452 // - quantizedType 453 // Per-layer quantized type. 
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) {
  // Flatten input if unranked, remembering its original shape so it can be
  // restored afterwards
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  Value inputShape;
  if (isUnranked)
    std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);

  // Process ranked tensor
  auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);

  // Restore original shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert an operation using per-channel quantization and a scalar or ranked
// tensor as an input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  // Result element type is the opposite side of the conversion: a float
  // input ('quant.qcast') produces storage-typed elements, while an integer
  // input ('quant.dcast') produces expressed-typed elements.
  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);

  // All dimensions iterate in parallel; scales and zero points are indexed by
  // the channel dimension only (1D operands mapped through 'channelAxis').
  SmallVector<utils::IteratorType> iteratorTypes(
      inputRank, utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
    builder.getMultiDimIdentityMap(inputRank),
    channelAxisAffineMap,
    channelAxisAffineMap,
    builder.getMultiDimIdentityMap(inputRank)
  };
  auto result = builder.create<linalg::GenericOp>(
      loc,
      init.getType(),  // resultType
      ValueRange{input, scales, zeroPoints},  // inputs
      ValueRange{init},  // outputs
      indexingMaps,
      iteratorTypes,
      [&](OpBuilder& builder, Location loc, ValueRange args) {
        // Region args: input element, scale, zero point, and the init element
        // (unused).
        assert(args.size() == 4);
        auto input = args[0];
        auto scale = args[1];
        auto zeroPoint = args[2];

        // Convert the scalar element using the per-channel scale/zero point;
        // the empty shape list marks the input as scalar.
        auto result = convertRanked(builder, loc, op, input, {}, scale,
                                    zeroPoint, quantizedType);

        builder.create<linalg::YieldOp>(loc, result);
      })
      .getResult(0);

  return result;
}

// Convert an operation using per-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
547 // 548 Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, 549 Value input, 550 UniformQuantizedPerAxisType quantizedType) { 551 // Flatten unranked tensor into a 3D ranked tensor if necessary 552 bool isUnranked = isa<UnrankedTensorType>(input.getType()); 553 int64_t channelAxis = quantizedType.getQuantizedDimension(); 554 int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); 555 Value inputShape; 556 if (isUnranked) { 557 std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( 558 builder, loc, input, channelAxis, channelAxisSize); 559 channelAxis = 1; 560 } 561 562 // Work on a ranked tensor 563 auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType, 564 channelAxis); 565 566 // Restore original tensor shape if unranked 567 if (isUnranked) 568 result = restoreUnrankedTensorShape(builder, loc, result, inputShape); 569 570 return result; 571 } 572 573 // Convert a quantization operation. 574 // 575 // - op 576 // 'quant.dcast' or 'quant.qcast' op. 577 // 578 // - input 579 // Scalar, ranked tensor, or unranked tensor. The element type matches 580 // the storage type (quant.dcast) or expressed type (quant.qcast) of 581 // 'quantizedType'. 582 // 583 // - quantizedType 584 // Per-layer or per-channel quantized type. 
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  if (auto uniformQuantizedType =
          dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);

  if (auto uniformQuantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input,
                             uniformQuantizedPerAxisType);

  llvm_unreachable("unexpected quantized type");
}

// Lowering pattern for 'quant.dcast'
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    // The scalar element type of the input is the quantized type to lower.
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Convert quantized input to storage type
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = rewriter.create<quant::StorageCastOp>(
        loc, storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};

// Lowering pattern for 'quant.qcast'
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    // The scalar element type of the result is the quantized type to lower.
    auto quantizedType = getScalarType(op.getResult().getType());

    // Lower to core-dialect ops; handles scalar, ranked, and unranked inputs
    // (any flattening of unranked tensors happens inside convertQuantized).
    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Cast stored value to result quantized value
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};

// Pass that converts all 'quant' dialect ops (except 'quant.scast') into
// arith/linalg/shape/tensor ops.
struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateLowerQuantOpsPatterns(patterns);

    ConversionTarget target(getContext());
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<
      arith::ArithDialect,
      linalg::LinalgDialect,
      shape::ShapeDialect,
      tensor::TensorDialect
    >();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

// Populate 'patterns' with the lowerings for 'quant.dcast' and 'quant.qcast'.
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<
    DequantizeCastOpConversion,
    QuantizeCastOpConversion
  >(patterns.getContext());
}

} // namespace quant
} // namespace mlir