1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // \file 10 // This file implements the TOSA Specification: 11 // https://developer.mlplatform.org/w/tosa/ 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 16 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 17 #include "mlir/Dialect/Quant/IR/Quant.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 20 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/IR/DialectImplementation.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/IR/TypeUtilities.h" 27 #include "mlir/Interfaces/InferTypeOpInterface.h" 28 #include "mlir/Transforms/InliningUtils.h" 29 #include "llvm/ADT/APFloat.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/TypeSwitch.h" 32 33 #include <numeric> 34 35 using namespace mlir; 36 using namespace mlir::tosa; 37 38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc" 39 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 40 41 //===----------------------------------------------------------------------===// 42 // Tosa dialect interface includes. 43 //===----------------------------------------------------------------------===// 44 45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" 46 47 namespace { 48 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc" 49 50 //===----------------------------------------------------------------------===// 51 // Dialect Function Inliner Interface. 52 //===----------------------------------------------------------------------===// 53 struct TosaInlinerInterface : public DialectInlinerInterface { 54 using DialectInlinerInterface::DialectInlinerInterface; 55 56 //===--------------------------------------------------------------------===// 57 // Analysis Hooks. 58 //===--------------------------------------------------------------------===// 59 60 /// All operations can be inlined by default. 61 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, 62 IRMapping &map) const final { 63 return true; 64 } 65 66 /// All regions with If and While parent operators can be inlined. 67 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 68 IRMapping &map) const final { 69 return (isa<tosa::IfOp>(dest->getParentOp()) || 70 isa<tosa::WhileOp>(dest->getParentOp())); 71 } 72 }; 73 74 /// This class implements the bytecode interface for the Tosa dialect. 75 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface { 76 TosaDialectBytecodeInterface(Dialect *dialect) 77 : BytecodeDialectInterface(dialect) {} 78 79 //===--------------------------------------------------------------------===// 80 // Attributes 81 82 Attribute readAttribute(DialectBytecodeReader &reader) const override { 83 return ::readAttribute(getContext(), reader); 84 } 85 86 LogicalResult writeAttribute(Attribute attr, 87 DialectBytecodeWriter &writer) const override { 88 return ::writeAttribute(attr, writer); 89 } 90 91 //===--------------------------------------------------------------------===// 92 // Types 93 94 Type readType(DialectBytecodeReader &reader) const override { 95 return ::readType(getContext(), reader); 96 } 97 98 LogicalResult writeType(Type type, 99 DialectBytecodeWriter &writer) const override { 100 return ::writeType(type, writer); 101 } 102 103 void writeVersion(DialectBytecodeWriter &writer) const final { 104 // TODO: Populate. 105 } 106 107 std::unique_ptr<DialectVersion> 108 readVersion(DialectBytecodeReader &reader) const final { 109 // TODO: Populate 110 reader.emitError("Dialect does not support versioning"); 111 return nullptr; 112 } 113 114 LogicalResult upgradeFromVersion(Operation *topLevelOp, 115 const DialectVersion &version) const final { 116 return success(); 117 } 118 }; 119 120 } // namespace 121 122 //===----------------------------------------------------------------------===// 123 // TOSA control flow support. 124 //===----------------------------------------------------------------------===// 125 126 /// Returns the while loop body. 127 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; } 128 129 //===----------------------------------------------------------------------===// 130 // Tosa dialect initialization. 131 //===----------------------------------------------------------------------===// 132 133 void TosaDialect::initialize() { 134 addTypes< 135 #define GET_TYPEDEF_LIST 136 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc" 137 >(); 138 addOperations< 139 #define GET_OP_LIST 140 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" 141 >(); 142 addAttributes< 143 #define GET_ATTRDEF_LIST 144 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" 145 >(); 146 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>(); 147 declarePromisedInterfaces< 148 mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp, 149 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp, 150 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp, 151 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp, 152 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp, 153 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp, 154 GreaterEqualOp, MatMulOp>(); 155 } 156 157 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, 158 Type type, Location loc) { 159 // Tosa dialect constants only support ElementsAttr unlike standard dialect 160 // constant which supports all attributes. 161 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) { 162 return builder.create<tosa::ConstShapeOp>( 163 loc, type, llvm::cast<DenseIntElementsAttr>(value)); 164 } 165 if (llvm::isa<ElementsAttr>(value)) 166 return builder.create<tosa::ConstOp>(loc, type, 167 llvm::cast<ElementsAttr>(value)); 168 return nullptr; 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Parsers and printers 173 //===----------------------------------------------------------------------===// 174 175 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, 176 Attribute &attr) { 177 if (succeeded(parser.parseOptionalEqual())) { 178 if (failed(parser.parseAttribute(attr))) { 179 return parser.emitError(parser.getCurrentLocation()) 180 << "expected attribute"; 181 } 182 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) { 183 typeAttr = TypeAttr::get(typedAttr.getType()); 184 } 185 return success(); 186 } 187 188 Type type; 189 if (failed(parser.parseColonType(type))) { 190 return parser.emitError(parser.getCurrentLocation()) << "expected type"; 191 } 192 typeAttr = TypeAttr::get(type); 193 194 return success(); 195 } 196 197 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, 198 Attribute attr) { 199 bool needsSpace = false; 200 auto typedAttr = dyn_cast_or_null<TypedAttr>(attr); 201 if (!typedAttr || typedAttr.getType() != type.getValue()) { 202 p << ": "; 203 p.printAttribute(type); 204 needsSpace = true; // subsequent attr value needs a space separator 205 } 206 if (attr) { 207 if (needsSpace) 208 p << ' '; 209 p << "= "; 210 p.printAttribute(attr); 211 } 212 } 213 214 //===----------------------------------------------------------------------===// 215 // TOSA Operator Verifiers. 216 //===----------------------------------------------------------------------===// 217 218 template <typename T> 219 static LogicalResult verifyConvOp(T op) { 220 // All TOSA conv ops have an input() and weight(). 221 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType()); 222 223 RankedTensorType weightType; 224 if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) 225 weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType()); 226 else 227 weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType()); 228 229 // Must be ranked tensor types 230 if (!inputType) { 231 op.emitOpError("expect a ranked tensor for input, got ") << op.getInput(); 232 return failure(); 233 } 234 if (!weightType) { 235 if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) { 236 op.emitOpError("expect a ranked tensor for filter, got ") 237 << op.getFilter(); 238 } else { 239 op.emitOpError("expect a ranked tensor for weight, got ") 240 << op.getWeight(); 241 } 242 return failure(); 243 } 244 245 auto inputEType = inputType.getElementType(); 246 auto weightEType = weightType.getElementType(); 247 248 bool inputIsQuant = !llvm::isa<FloatType>(inputEType); 249 bool weightIsQuant = !llvm::isa<FloatType>(weightEType); 250 251 // Either both must be quantized or both unquantized. 252 if (inputIsQuant != weightIsQuant) { 253 op.emitOpError( 254 "expect both input and weight to be float or not together, got ") 255 << inputEType << " and " << weightEType; 256 return failure(); 257 } 258 259 // Quantized type must have constructed the quantizationattr, and unquantized 260 // types should not have a quantizationattr. 261 if ((inputIsQuant && !op.getQuantizationInfo()) || 262 (!inputIsQuant && op.getQuantizationInfo())) { 263 op.emitOpError("quantizationattr is required for quantized type, and not " 264 "allowed for float type"); 265 return failure(); 266 } 267 return success(); 268 } 269 270 LogicalResult tosa::ConstOp::verify() { 271 272 auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType()); 273 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType()); 274 275 if (!attrType || !outputType) { 276 emitOpError("expected tensors for attr/result type"); 277 return failure(); 278 } 279 280 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>( 281 outputType.getElementType())) { 282 if (result.getStorageType() == attrType.getElementType()) 283 return success(); 284 } 285 286 if (attrType.getElementType() != outputType.getElementType()) { 287 emitOpError("expected same attr/result element types"); 288 return failure(); 289 } 290 291 return success(); 292 } 293 294 template <typename T> 295 static LogicalResult verifyConvOpModes(T op) { 296 auto inputEType = 297 llvm::cast<ShapedType>(op.getInput().getType()).getElementType(); 298 299 if (auto quantType = 300 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) 301 inputEType = quantType.getStorageType(); 302 303 auto accType = op.getAccType(); 304 if (inputEType.isInteger(8) && !accType.isInteger(32)) 305 return op.emitOpError("accumulator type for i8 tensor is not i32"); 306 307 if (inputEType.isInteger(16) && !accType.isInteger(48)) 308 return op.emitOpError("accumulator type for i16 tensor is not i48"); 309 310 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16()) 311 return op.emitOpError("accumulator type for f8 tensor is not f16"); 312 313 if (inputEType.isF16() && !(accType.isF16() || accType.isF32())) 314 return op.emitOpError("accumulator type for f16 tensor is not f16/f32"); 315 316 if (inputEType.isBF16() && !accType.isF32()) 317 return op.emitOpError("accumulator type for bf16 tensor is not f32"); 318 319 if (inputEType.isF32() && !accType.isF32()) 320 return op.emitOpError("accumulator type for f32 tensor is not f32"); 321 322 return success(); 323 } 324 325 LogicalResult tosa::ArgMaxOp::verify() { 326 // Ensure output is of 32-bit integer 327 const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType(); 328 if (!resultETy.isIntOrIndex()) 329 return emitOpError("result tensor is not of integer type"); 330 331 // Ensure axis is within the tensor rank 332 const auto inputType = llvm::cast<ShapedType>(getInput().getType()); 333 const int64_t axis = getAxisAttr().getInt(); 334 if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank())) 335 return emitOpError("specified axis is outside the rank of the tensor"); 336 337 return success(); 338 } 339 340 LogicalResult tosa::AvgPool2dOp::verify() { 341 auto inputType = llvm::cast<ShapedType>(getInput().getType()); 342 343 auto inputETy = inputType.getElementType(); 344 auto resultETy = llvm::cast<ShapedType>(getType()).getElementType(); 345 346 if (auto quantType = 347 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) 348 inputETy = quantType.getStorageType(); 349 350 if (auto quantType = 351 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy)) 352 resultETy = quantType.getStorageType(); 353 354 auto accType = getAccType(); 355 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32)) 356 return emitOpError("accumulator type for integer tensor is not i32"); 357 358 if (inputETy.isF16() && !(accType.isF16() || accType.isF32())) 359 return emitOpError("accumulator type for f16 tensor is not f16/f32"); 360 361 if (inputETy.isBF16() && !accType.isF32()) 362 return emitOpError("accumulator type for bf16 tensor is not f32"); 363 364 if (inputETy.isF32() && !accType.isF32()) 365 return emitOpError("accumulator type for f32 tensor is not f32"); 366 367 if ((inputETy.isF32() && resultETy.isF32()) || 368 (inputETy.isF16() && resultETy.isF16()) || 369 (inputETy.isBF16() && resultETy.isBF16()) || 370 (inputETy.isInteger(8) && resultETy.isInteger(8)) || 371 (inputETy.isInteger(16) && resultETy.isInteger(16))) 372 return success(); 373 374 return emitOpError("input/output element types are incompatible."); 375 } 376 377 LogicalResult tosa::ClampOp::verify() { 378 mlir::Type inputETy = 379 llvm::cast<ShapedType>(getInput().getType()).getElementType(); 380 if (auto quantType = 381 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) { 382 inputETy = quantType.getStorageType(); 383 } 384 mlir::Type maxFpType = getMaxFpAttr().getType(); 385 mlir::Type minFpType = getMinFpAttr().getType(); 386 mlir::Type outputETy = 387 llvm::cast<ShapedType>(getOutput().getType()).getElementType(); 388 if (auto quantType = 389 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) { 390 outputETy = quantType.getStorageType(); 391 } 392 unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth(); 393 394 if (inputETy != outputETy) 395 return emitOpError("input/output element types are incompatible."); 396 397 // If input datatype is float, check that the two min/max_fp attributes 398 // share the same type and that their type is either the same of the input's 399 // datatype, or a float type whose bitwidth > input datatype bitwidth. 400 if (!inputETy.isInteger(dataTypeBitWidth)) { 401 if (((maxFpType != minFpType) || 402 (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <= 403 inputETy.getIntOrFloatBitWidth()))) 404 return emitOpError("min/max attributes types are incompatible with " 405 "input/output element types."); 406 } 407 408 return success(); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // TOSA Operator Quantization Builders. 413 //===----------------------------------------------------------------------===// 414 415 /// This builder is called on all convolution operators except TransposeConv, 416 /// which has specialized output shape semantics. The builder also defines the 417 /// bitwidth of the output given the bit width of the input & weight content. 418 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, 419 Type outputType, Value input, Value weight, 420 Value bias, DenseI64ArrayAttr pad, 421 DenseI64ArrayAttr stride, 422 DenseI64ArrayAttr dilation, 423 TypeAttr accType) { 424 425 result.addOperands({input, weight, bias}); 426 result.addAttribute("pad", pad); 427 result.addAttribute("stride", stride); 428 result.addAttribute("dilation", dilation); 429 result.addAttribute("acc_type", accType); 430 431 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); 432 if (quantAttr) { 433 result.addAttribute("quantization_info", quantAttr); 434 result.addTypes( 435 buildConvOpResultTypeInfo(builder, outputType, input, weight)); 436 } else { 437 result.addTypes(outputType); 438 } 439 } 440 441 /// Handles tosa.transpose_conv2d which has outpad and output shape 442 /// attributes. 443 static void buildTransConvOpWithQuantInfo( 444 OpBuilder &builder, OperationState &result, Type outputType, Value input, 445 Value weight, Value bias, DenseI64ArrayAttr outpad, 446 DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) { 447 result.addOperands({input, weight, bias}); 448 result.addAttribute("out_pad", outpad); 449 result.addAttribute("stride", stride); 450 result.addAttribute("out_shape", outputShape); 451 result.addAttribute("acc_type", accType); 452 auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); 453 454 if (quantAttr) { 455 result.addAttribute("quantization_info", quantAttr); 456 result.addTypes( 457 buildConvOpResultTypeInfo(builder, outputType, input, weight)); 458 } else { 459 result.addTypes(outputType); 460 } 461 } 462 463 /// The tosa.fully_connected op has its own builder as it does not have 464 /// strides/dilation/padding. 465 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, 466 Type outputType, Value input, Value weight, 467 Value bias) { 468 469 result.addOperands({input, weight, bias}); 470 auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); 471 if (quantAttr) { 472 result.addAttribute("quantization_info", quantAttr); 473 result.addTypes( 474 buildConvOpResultTypeInfo(builder, outputType, input, weight)); 475 } else { 476 result.addTypes(outputType); 477 } 478 } 479 480 /// The tosa.matmul op is also intended to be generated where a 481 /// fully_connected op must be constructed where the weight is not a constant. 482 /// In this case, the fully_connected op must be expressed using matmul. 483 /// TODO: Add link to the leglization document explaining this. 484 static void buildMatMulOpWithQuantInfo(OpBuilder &builder, 485 OperationState &result, Type outputType, 486 Value a, Value b) { 487 result.addOperands({a, b}); 488 auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b); 489 490 if (quantAttr) { 491 result.addAttribute("quantization_info", quantAttr); 492 493 auto inputType = llvm::dyn_cast<ShapedType>(a.getType()); 494 assert(inputType && "Input must be a shaped tensor type!"); 495 496 auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>( 497 inputType.getElementType()); 498 assert(inputQType && "Tensor must have quantized datatype!"); 499 500 unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); 501 502 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType); 503 assert(outputShapedType && "Output must be a shaped type"); 504 505 IntegerType accElementType; 506 if (inputBits == 16) 507 accElementType = builder.getIntegerType(48); 508 else 509 accElementType = builder.getI32Type(); 510 auto accType = outputShapedType.clone(accElementType); 511 result.addTypes(accType); 512 } else { 513 result.addTypes(outputType); 514 } 515 } 516 517 /// Both the tosa.avg_pool2d and unary ops use the same 518 /// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it 519 /// has additional parameters not part of the unary ops. 520 static void 521 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, 522 Type outputType, Value input, 523 DenseArrayAttr kernel, DenseArrayAttr stride, 524 DenseArrayAttr pad, TypeAttr accType) { 525 result.addOperands(input); 526 result.addAttribute("kernel", kernel); 527 result.addAttribute("stride", stride); 528 result.addAttribute("pad", pad); 529 result.addAttribute("acc_type", accType); 530 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); 531 if (quantAttr) 532 result.addAttribute("quantization_info", quantAttr); 533 result.types.push_back(outputType); 534 } 535 536 /// This builder is called on single-parameter unary operators that have scale 537 /// relationship between their input and output, expressed by the 538 /// UnaryOpQuantizationAttr. 539 static void buildUnaryOpWithQuantInfo(OpBuilder &builder, 540 OperationState &result, Type outputType, 541 Value input) { 542 result.addOperands(input); 543 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); 544 if (quantAttr) 545 result.addAttribute("quantization_info", quantAttr); 546 result.types.push_back(outputType); 547 } 548 549 /// This builder is called on TOSA pad operator that needs to create its own 550 /// OptionalAttr quantization_attr parameter to scale the padding values 551 /// correctly. No pad_const is interpreted as zero-padding. 552 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, 553 Type outputType, Value input, 554 Value paddings) { 555 result.addOperands({input, paddings}); 556 auto quantAttr = buildPadOpQuantizationAttr(builder, input); 557 if (quantAttr) 558 result.addAttribute("quantization_info", quantAttr); 559 result.types.push_back(outputType); 560 } 561 562 /// This builder is called on TOSA pad operator when an explicit pad_const 563 /// value is passed in. It also optionally constructs quantization_attr. 564 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder, 565 OperationState &result, 566 Type outputType, Value input, 567 Value paddings, 568 Value padConst) { 569 result.addOperands({input, paddings, padConst}); 570 auto quantAttr = buildPadOpQuantizationAttr(builder, input); 571 if (quantAttr) 572 result.addAttribute("quantization_info", quantAttr); 573 result.types.push_back(outputType); 574 } 575 576 //===----------------------------------------------------------------------===// 577 // TOSA Operator Return Type Inference. 578 //===----------------------------------------------------------------------===// 579 580 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, 581 SmallVector<int64_t> &outShape) { 582 int64_t outRank = 0; 583 for (int i = 0, e = operands.size(); i != e; ++i) { 584 auto shape = operands.getShape(i); 585 if (!shape.hasRank()) { 586 // TODO(jennik): Update function to have better case handling for 587 // invalid operands and for ranked tensors. 588 return failure(); 589 } 590 outRank = std::max<int64_t>(outRank, shape.getRank()); 591 } 592 593 outShape.resize(outRank, 1); 594 595 for (int i = 0, e = operands.size(); i != e; ++i) { 596 auto shape = operands.getShape(i); 597 auto rankDiff = outShape.size() - shape.getRank(); 598 599 for (size_t i = 0, e = shape.getRank(); i < e; ++i) { 600 auto dim1 = outShape[i + rankDiff]; 601 auto dim2 = shape.getDimSize(i); 602 auto resolvedDim = dim1; 603 604 if (dim1 == 1) { 605 resolvedDim = dim2; 606 } else if (dim2 == 1) { 607 resolvedDim = dim1; 608 } else if (dim1 != dim2) { 609 return failure(); 610 } 611 outShape[i + rankDiff] = resolvedDim; 612 } 613 } 614 615 return success(); 616 } 617 618 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents( 619 MLIRContext *context, ::std::optional<Location> location, 620 ArgMaxOp::Adaptor adaptor, 621 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 622 ShapeAdaptor inputShape(adaptor.getInput().getType()); 623 IntegerAttr axis = adaptor.getProperties().axis; 624 int32_t axisVal = axis.getValue().getSExtValue(); 625 626 if (!inputShape.hasRank()) { 627 inferredReturnShapes.push_back(ShapedTypeComponents()); 628 return success(); 629 } 630 631 SmallVector<int64_t> outShape; 632 outShape.reserve(inputShape.getRank() - 1); 633 for (int i = 0, s = inputShape.getRank(); i < s; i++) { 634 if (i == axisVal) 635 continue; 636 outShape.push_back(inputShape.getDimSize(i)); 637 } 638 639 inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); 640 return success(); 641 } 642 643 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents( 644 MLIRContext *context, ::std::optional<Location> location, 645 RFFT2dOp::Adaptor adaptor, 646 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 647 ShapeAdaptor inputShape(adaptor.getInput().getType()); 648 649 if (!inputShape.hasRank()) 650 return failure(); 651 652 llvm::SmallVector<int64_t> outputShape; 653 outputShape.resize(3, ShapedType::kDynamic); 654 outputShape[0] = inputShape.getDimSize(0); 655 outputShape[1] = inputShape.getDimSize(1); 656 int64_t inWidth = inputShape.getDimSize(2); 657 658 // Note that we can support this calculation symbolically 659 // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1] 660 if (inWidth != ShapedType::kDynamic) 661 outputShape[2] = inWidth / 2 + 1; 662 663 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 664 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 665 666 return success(); 667 } 668 669 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents( 670 MLIRContext *context, ::std::optional<Location> location, 671 FFT2dOp::Adaptor adaptor, 672 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 673 inferredReturnShapes.push_back( 674 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType()))); 675 inferredReturnShapes.push_back( 676 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType()))); 677 return success(); 678 } 679 680 LogicalResult tosa::ConcatOp::inferReturnTypeComponents( 681 MLIRContext *context, ::std::optional<Location> location, 682 ConcatOp::Adaptor adaptor, 683 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 684 // Infer all dimension sizes by reducing based on inputs. 685 const Properties &prop = adaptor.getProperties(); 686 int32_t axis = prop.axis.getValue().getSExtValue(); 687 llvm::SmallVector<int64_t> outputShape; 688 bool hasRankedInput = false; 689 for (auto operand : adaptor.getOperands()) { 690 ShapeAdaptor operandShape(operand.getType()); 691 if (!operandShape.hasRank()) 692 continue; 693 694 // Copy the Operand's rank. 695 if (!hasRankedInput) 696 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic); 697 698 // Copy shapes until the dim is non-dynamic. 699 for (int i = 0, s = operandShape.getRank(); i < s; i++) { 700 if (i == axis || operandShape.isDynamicDim(i)) 701 continue; 702 if (outputShape[i] == ShapedType::kDynamic) 703 outputShape[i] = operandShape.getDimSize(i); 704 if (outputShape[i] != operandShape.getDimSize(i)) 705 return emitOptionalError(location, 706 "Cannot concat tensors with different sizes" 707 " on the non-axis dimension ", 708 i); 709 } 710 711 hasRankedInput = true; 712 } 713 Type inputType = 714 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType(); 715 if (!hasRankedInput) { 716 inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); 717 return success(); 718 } 719 720 // Determine the dimension size along the concatenation axis. 721 int64_t concatDimSize = 0; 722 for (auto operand : adaptor.getOperands()) { 723 ShapeAdaptor operandShape(operand.getType()); 724 725 // We need to know the length of the concatenation axis of all inputs to 726 // determine the dimension size of the output shape. 727 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) { 728 concatDimSize = ShapedType::kDynamic; 729 break; 730 } 731 732 concatDimSize += operandShape.getDimSize(axis); 733 } 734 735 outputShape[axis] = concatDimSize; 736 737 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); 738 return success(); 739 } 740 741 LogicalResult tosa::EqualOp::inferReturnTypeComponents( 742 MLIRContext *context, ::std::optional<Location> location, 743 ValueShapeRange operands, DictionaryAttr attributes, 744 OpaqueProperties properties, RegionRange regions, 745 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 746 auto elementType = IntegerType::get(context, /*width=*/1); 747 748 llvm::SmallVector<int64_t> outShape; 749 if (resolveBroadcastShape(operands, outShape).failed()) { 750 inferredReturnShapes.push_back(ShapedTypeComponents(elementType)); 751 return success(); 752 } 753 754 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType)); 755 return success(); 756 } 757 758 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 759 if (l.size() != r.size() || l.size() != 1) 760 return false; 761 return succeeded(verifyCompatibleShape(l[0], r[0])); 762 } 763 764 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( 765 MLIRContext *context, ::std::optional<Location> location, 766 FullyConnectedOp::Adaptor adaptor, 767 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 768 ShapeAdaptor inputShape(adaptor.getInput().getType()); 769 ShapeAdaptor weightShape(adaptor.getWeight().getType()); 770 ShapeAdaptor biasShape(adaptor.getBias().getType()); 771 772 // All shapes are dynamic. 773 SmallVector<int64_t> outShape; 774 outShape.resize(2, ShapedType::kDynamic); 775 776 if (inputShape.hasRank()) { 777 outShape[0] = inputShape.getDimSize(0); 778 } 779 780 if (weightShape.hasRank()) { 781 outShape[1] = weightShape.getDimSize(0); 782 } 783 784 if (biasShape.hasRank()) { 785 outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0) 786 : outShape[1]; 787 } 788 789 inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); 790 return success(); 791 } 792 793 LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } 794 795 LogicalResult tosa::MatMulOp::inferReturnTypeComponents( 796 MLIRContext *context, ::std::optional<Location> location, 797 MatMulOp::Adaptor adaptor, 798 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 799 ShapeAdaptor lhsShape(adaptor.getA().getType()); 800 ShapeAdaptor rhsShape(adaptor.getB().getType()); 801 802 // All shapes are dynamic. 803 SmallVector<int64_t> outShape; 804 outShape.resize(3, ShapedType::kDynamic); 805 806 if (lhsShape.hasRank()) { 807 outShape[0] = lhsShape.getDimSize(0); 808 outShape[1] = lhsShape.getDimSize(1); 809 } 810 811 if (rhsShape.hasRank()) { 812 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0) 813 : outShape[0]; 814 outShape[2] = rhsShape.getDimSize(2); 815 } 816 817 inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); 818 return success(); 819 } 820 821 LogicalResult tosa::PadOp::inferReturnTypeComponents( 822 MLIRContext *context, ::std::optional<Location> location, 823 PadOp::Adaptor adaptor, 824 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 825 ShapeAdaptor inputShape(adaptor.getInput1().getType()); 826 auto paddingRank = 827 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank(); 828 SmallVector<int64_t> outputShape; 829 830 // If the input rank is unknown, we can infer the output rank using the 831 // padding shape's rank divided by 2. 832 if (!inputShape.hasRank()) { 833 outputShape.resize(paddingRank / 2, ShapedType::kDynamic); 834 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 835 return success(); 836 } 837 838 SmallVector<int64_t> paddingValues; 839 // If the paddings value is not a constant, all dimensions must be dynamic. 840 if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(), 841 paddingValues)) { 842 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); 843 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 844 return success(); 845 } 846 847 outputShape.reserve(inputShape.getRank()); 848 for (int i = 0, s = inputShape.getRank(); i < s; i++) { 849 if (inputShape.isDynamicDim(i)) { 850 outputShape.push_back(ShapedType::kDynamic); 851 continue; 852 } 853 auto padFront = paddingValues[i * 2]; 854 auto padBack = paddingValues[i * 2 + 1]; 855 if (padFront < 0 || padBack < 0) { 856 // if either padding for dim i is -1, output dim is unknown 857 outputShape.push_back(ShapedType::kDynamic); 858 continue; 859 } 860 861 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack); 862 } 863 864 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 865 return success(); 866 } 867 868 LogicalResult tosa::PadOp::verify() { 869 RankedTensorType inputType = getInput1().getType(); 870 RankedTensorType outputType = getOutput().getType(); 871 auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank(); 872 873 if (inputType.getRank() != outputType.getRank()) 874 return emitOpError() << "expect same input and output tensor rank."; 875 876 if (paddingRank != inputType.getRank() * 2) 877 return emitOpError() << "expected padding tensor dim 0 to have size " 878 << inputType.getRank() * 2 879 << " (2*rank(shape1)) but got size " << paddingRank; 880 881 return success(); 882 } 883 884 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) { 885 return to_vector(llvm::map_range(shape, [](int64_t dim) { 886 return dim == -1 ? ShapedType::kDynamic : dim; 887 })); 888 } 889 890 LogicalResult tosa::SliceOp::inferReturnTypeComponents( 891 MLIRContext *context, ::std::optional<Location> location, 892 SliceOp::Adaptor adaptor, 893 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 894 895 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); 896 SmallVector<int64_t> start; 897 SmallVector<int64_t> size; 898 899 if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) || 900 !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) { 901 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank(); 902 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic); 903 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType)); 904 return success(); 905 } 906 907 // if size[i] is -1, all remaining elements in dimension i are included 908 // in the slice, similar to TF. 909 ShapeAdaptor inputShape(adaptor.getInput1().getType()); 910 // initialize outputShape to all unknown 911 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic); 912 if (inputShape.hasRank()) { 913 for (size_t i = 0; i < size.size(); i++) { 914 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 && 915 (ShapedType::isDynamic(inputShape.getDimSize(i)) || 916 start[i] < inputShape.getDimSize(i))) { 917 // size[i] is not 0 and not < -1, and start[i] is in valid range 918 if (ShapedType::isDynamic(inputShape.getDimSize(i))) { 919 // input shape has unknown dim[i] - only valid if size[i] > 0 920 if (size[i] > 0) { 921 outputShape[i] = size[i]; 922 } 923 } else { 924 // input shape has known dim[i] 925 if (size[i] == -1) { 926 outputShape[i] = inputShape.getDimSize(i) - start[i]; 927 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) { 928 // start[i] + size[i] is within bound of input shape's dim[i] 929 outputShape[i] = size[i]; 930 } 931 } 932 } 933 } 934 } else { 935 outputShape = convertToMlirShape(size); 936 } 937 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 938 return success(); 939 } 940 941 LogicalResult tosa::SliceOp::verify() { 942 auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType()); 943 if (!inputType) 944 return success(); 945 946 auto startShapeRank = 947 llvm::cast<tosa::shapeType>(getStart().getType()).getRank(); 948 if (inputType.getRank() != startShapeRank) 949 return emitOpError( 950 "length of start attribute is not equal rank of input shape"); 951 952 auto sizeShapeRank = 953 llvm::cast<tosa::shapeType>(getSize().getType()).getRank(); 954 if (inputType.getRank() != sizeShapeRank) 955 return emitOpError( 956 "length of size attribute is not equal rank of input shape"); 957 958 return success(); 959 } 960 961 LogicalResult tosa::MulOp::verify() { 962 auto resElemType = getElementTypeOrSelf(getOutput()); 963 964 // Verify if the element type among operands and result match tosa 965 // specification. 966 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) { 967 IntegerType lhsIntType = 968 cast<IntegerType>(getElementTypeOrSelf(getInput1())); 969 IntegerType rhsIntType = 970 cast<IntegerType>(getElementTypeOrSelf(getInput2())); 971 if (lhsIntType != rhsIntType) 972 return emitOpError("requires the same element type for all operands"); 973 974 // Though the spec requires the element type of result to be i32, a more 975 // relaxed way is provided at dialect level for easier cooperating with 976 // other dialects. 977 if (lhsIntType.getWidth() > resIntType.getWidth()) 978 return emitOpError("invalid data type size for operands or result"); 979 980 } else { 981 // For other supported type, the spec requires requires the same element 982 // type for all operands (excludes `shift` operand) and results. 983 for (int i = 0; i < 2; ++i) { 984 if (getElementTypeOrSelf(getOperand(i)) != resElemType) 985 return emitOpError( 986 "requires the same element type for all operands and results"); 987 } 988 } 989 990 // Verify the op has same ranks for all main operands (excludes extra operands 991 // such as shift of mul op, so this is the only difference with the built-in 992 // `SameOperandsAndResultRank` trait) and results types, if known. 993 994 // delegate function that returns true if type is a shaped type with known 995 // rank 996 auto hasRank = [](const Type type) { 997 if (auto shaped_type = dyn_cast<ShapedType>(type)) 998 return shaped_type.hasRank(); 999 1000 return false; 1001 }; 1002 1003 auto rankedOperandTypes = 1004 llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank)); 1005 1006 auto rankedResultTypes = 1007 llvm::make_filter_range(getOperation()->getResultTypes(), hasRank); 1008 1009 // If all operands and results are unranked, then no further verification. 1010 if (rankedOperandTypes.empty() && rankedResultTypes.empty()) 1011 return success(); 1012 1013 // delegate function that returns rank of shaped type with known rank 1014 auto getRank = [](const Type type) { 1015 return cast<ShapedType>(type).getRank(); 1016 }; 1017 1018 auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin()) 1019 : getRank(*rankedResultTypes.begin()); 1020 1021 for (size_t i = 0; i < 2; ++i) { 1022 if (rank != getRank(rankedOperandTypes[i])) { 1023 return emitOpError("operands don't have matching ranks"); 1024 } 1025 } 1026 1027 for (const auto type : rankedResultTypes) { 1028 if (rank != getRank(type)) { 1029 return emitOpError("result type has different rank than operands"); 1030 } 1031 } 1032 1033 return success(); 1034 } 1035 1036 LogicalResult tosa::TableOp::inferReturnTypeComponents( 1037 MLIRContext *context, ::std::optional<Location> location, 1038 TableOp::Adaptor adaptor, 1039 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1040 ShapeAdaptor inputShape(adaptor.getInput1().getType()); 1041 1042 if (!inputShape.hasRank()) { 1043 inferredReturnShapes.push_back(ShapedTypeComponents()); 1044 return success(); 1045 } 1046 1047 inferredReturnShapes.resize(1); 1048 inputShape.getDims(inferredReturnShapes[0]); 1049 return success(); 1050 } 1051 1052 LogicalResult tosa::TableOp::verify() { 1053 TensorType inputType = getInput1().getType(); 1054 TensorType outputType = getOutput().getType(); 1055 1056 if (inputType.hasRank() && outputType.hasRank() && 1057 inputType.getRank() != outputType.getRank()) 1058 return emitOpError() 1059 << "expected input tensor rank to equal result tensor rank"; 1060 1061 auto inputDims = inputType.getShape(); 1062 auto outputDims = outputType.getShape(); 1063 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) { 1064 int64_t dim = it.index(); 1065 auto [inputDim, outputDim] = it.value(); 1066 if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) { 1067 return emitOpError() << "dim(result, " << dim << ") = " << outputDim 1068 << " doesn't match dim(input, " << dim 1069 << ") = " << inputDim; 1070 } 1071 } 1072 return success(); 1073 } 1074 1075 LogicalResult 1076 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) { 1077 // Multiples must be constants. 1078 DenseIntElementsAttr multiplesAttr; 1079 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr))) 1080 return failure(); 1081 multiples = llvm::to_vector( 1082 llvm::map_range(multiplesAttr.getValues<APInt>(), 1083 [](const APInt &val) { return val.getSExtValue(); })); 1084 return success(); 1085 } 1086 1087 LogicalResult tosa::TileOp::inferReturnTypeComponents( 1088 MLIRContext *context, ::std::optional<Location> location, 1089 TileOp::Adaptor adaptor, 1090 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1091 DenseIntElementsAttr multiplesAttr; 1092 if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr))) 1093 return failure(); 1094 1095 SmallVector<int64_t> multiples = llvm::to_vector( 1096 llvm::map_range(multiplesAttr.getValues<APInt>(), 1097 [](const APInt &val) { return val.getSExtValue(); })); 1098 1099 ShapeAdaptor inputShape(adaptor.getInput1().getType()); 1100 SmallVector<int64_t> outputShape; 1101 if (!inputShape.hasRank()) { 1102 outputShape.resize(multiples.size(), ShapedType::kDynamic); 1103 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1104 return success(); 1105 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size()) 1106 return failure(); 1107 1108 // Any non dynamic dimension can be multiplied to a known size. 1109 outputShape.reserve(multiples.size()); 1110 for (int i = 0, s = inputShape.getRank(); i < s; i++) { 1111 int64_t dim = inputShape.getDimSize(i); 1112 if (dim != ShapedType::kDynamic) 1113 dim *= multiples[i]; 1114 outputShape.push_back(dim); 1115 } 1116 1117 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1118 return success(); 1119 } 1120 1121 LogicalResult tosa::TileOp::verify() { 1122 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType()); 1123 ShapedType outputType = llvm::cast<ShapedType>(getType()); 1124 1125 shapeType multiplesType = 1126 llvm::cast<tosa::shapeType>(getMultiples().getType()); 1127 1128 auto multiplesRank = multiplesType.getRank(); 1129 1130 if (inputType.hasRank()) { 1131 if (inputType.getRank() != multiplesRank) 1132 return emitOpError("expect 'multiples' to have rank ") 1133 << inputType.getRank() << " but got " << multiplesRank << "."; 1134 if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) 1135 return emitOpError("expect same input and output tensor rank."); 1136 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank) 1137 return emitOpError("expect 'multiples' array to have length ") 1138 << outputType.getRank() << " but got " << multiplesRank << "."; 1139 1140 SmallVector<int64_t> multiples; 1141 if (getConstantMultiples(multiples).succeeded() && 1142 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; })) 1143 return emitOpError( 1144 "expect element of 'multiples' to be positive integer or -1."); 1145 1146 return success(); 1147 } 1148 1149 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1150 if (l.size() != r.size() || l.size() != 1) 1151 return false; 1152 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]); 1153 } 1154 1155 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( 1156 MLIRContext *context, ::std::optional<Location> location, 1157 ReshapeOp::Adaptor adaptor, 1158 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1159 ShapeAdaptor inputShape(adaptor.getInput1().getType()); 1160 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); 1161 llvm::SmallVector<int64_t> newShapeValue = 1162 convertToMlirShape(adaptor.getNewShape()); 1163 1164 // We cannot infer from the total number of elements so we must take the 1165 // shape attribute as exact. 1166 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) { 1167 inferredReturnShapes.push_back( 1168 ShapedTypeComponents(newShapeValue, inputType)); 1169 return success(); 1170 } 1171 1172 // Determine the number of elements covered by the slice of all static 1173 // dimensions. This allows us to infer the length of the remaining dynamic 1174 // dimension. 1175 int64_t numElements = inputShape.getNumElements(); 1176 int64_t staticMul = 1; 1177 for (auto val : newShapeValue) { 1178 if (!ShapedType::isDynamic(val)) { 1179 staticMul *= val; 1180 } 1181 } 1182 1183 // Determine the length of the dynamic dimension. 1184 for (auto &val : newShapeValue) { 1185 if (ShapedType::isDynamic(val)) 1186 val = numElements / staticMul; 1187 } 1188 1189 inferredReturnShapes.push_back( 1190 ShapedTypeComponents(newShapeValue, inputType)); 1191 return success(); 1192 } 1193 1194 llvm::LogicalResult tosa::ReshapeOp::verify() { 1195 TensorType inputType = getInput1().getType(); 1196 RankedTensorType outputType = getType(); 1197 1198 if ((int64_t)getNewShape().size() != outputType.getRank()) 1199 return emitOpError() << "new shape does not match result rank"; 1200 1201 for (auto [newShapeDim, outputShapeDim] : 1202 zip(getNewShape(), outputType.getShape())) { 1203 if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic && 1204 newShapeDim != outputShapeDim) 1205 return emitOpError() << "new shape is inconsistent with result shape"; 1206 1207 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1) 1208 return emitOpError() << "new shape has invalid tensor dimension size " 1209 << newShapeDim; 1210 } 1211 1212 if (inputType.hasStaticShape()) { 1213 int64_t inputElementsNum = inputType.getNumElements(); 1214 if (outputType.hasStaticShape()) { 1215 int64_t outputElementsNum = outputType.getNumElements(); 1216 if (inputElementsNum != outputElementsNum) { 1217 return emitOpError() << "cannot reshape " << inputElementsNum 1218 << " elements into " << outputElementsNum; 1219 } 1220 } 1221 1222 int64_t newShapeElementsNum = std::accumulate( 1223 getNewShape().begin(), getNewShape().end(), 1LL, 1224 [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; }); 1225 bool isStaticNewShape = 1226 llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; }); 1227 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) || 1228 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) { 1229 return emitOpError() << "cannot reshape " << inputElementsNum 1230 << " elements into " << newShapeElementsNum; 1231 } 1232 } 1233 1234 int missingDims = llvm::count(getNewShape(), -1); 1235 if (missingDims > 1) 1236 return emitOpError() << "expected at most one target dimension to be -1"; 1237 1238 return mlir::success(); 1239 } 1240 1241 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) { 1242 // Perms must be constants. 1243 DenseIntElementsAttr permsAttr; 1244 if (!matchPattern(getPerms(), m_Constant(&permsAttr))) 1245 return failure(); 1246 1247 perms.clear(); 1248 for (auto v : permsAttr.getValues<APInt>()) 1249 perms.push_back(v.getSExtValue()); 1250 1251 return success(); 1252 } 1253 1254 LogicalResult tosa::TransposeOp::inferReturnTypeComponents( 1255 MLIRContext *context, ::std::optional<Location> location, 1256 TransposeOp::Adaptor adaptor, 1257 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1258 ShapeAdaptor inputShape(adaptor.getInput1().getType()); 1259 ShapeAdaptor permsShape(adaptor.getPerms().getType()); 1260 1261 // We cannot infer anything from a rank-0 "permutation" tensor. 1262 if (permsShape.hasRank() && permsShape.getRank() == 0) 1263 return failure(); 1264 1265 // If input rank and permutation length is unknown, the output rank is 1266 // unknown. 1267 if (!inputShape.hasRank() || !permsShape.hasRank() || 1268 permsShape.isDynamicDim(0)) { 1269 inferredReturnShapes.push_back(ShapedTypeComponents()); 1270 return success(); 1271 } 1272 1273 // This would imply the number of permutations does not match the rank of 1274 // the input which is illegal. 1275 if (permsShape.getDimSize(0) != inputShape.getRank()) { 1276 return failure(); 1277 } 1278 1279 SmallVector<int64_t> outputShape; 1280 // Rank-0 means no permutations matter. 1281 if (inputShape.getRank() == 0) { 1282 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1283 return success(); 1284 } 1285 1286 // Check whether the input dimensions are all the same. 1287 bool allTheSame = true; 1288 for (int i = 1, s = inputShape.getRank(); i < s; i++) { 1289 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) { 1290 allTheSame = false; 1291 break; 1292 } 1293 } 1294 1295 // If all of the input dimensions are the same we don't care about the 1296 // permutation. 1297 if (allTheSame) { 1298 outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0)); 1299 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1300 return success(); 1301 } 1302 1303 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); 1304 // If the permuations are a constant we can directly determine the output 1305 // shape. 1306 DenseIntElementsAttr attr; 1307 if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) && 1308 attr.getType().getRank() == 1) { 1309 ShapeAdaptor permShape = attr; 1310 // Constant permutation must be the same length as the input rank. 1311 if (inputShape.getRank() != permShape.getRank()) 1312 return emitOptionalError(location, 1313 "constant permutation must be the same length" 1314 " as the input rank"); 1315 1316 // Constant permutation values must be within the input rank. 1317 for (int i = 0, e = inputShape.getRank(); i < e; i++) { 1318 if (inputShape.getRank() <= permShape.getDimSize(i)) 1319 return failure(); 1320 } 1321 1322 outputShape.reserve(inputShape.getRank()); 1323 for (int i = 0, s = inputShape.getRank(); i < s; i++) { 1324 outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i)); 1325 } 1326 } 1327 1328 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1329 return success(); 1330 } 1331 1332 LogicalResult tosa::TransposeOp::verify() { 1333 TensorType inputType = getInput1().getType(); 1334 TensorType permType = getPerms().getType(); 1335 TensorType outputType = getOutput().getType(); 1336 1337 if (permType.hasRank() && permType.getRank() != 1) 1338 return emitOpError() 1339 << "expected permutation tensor to be rank 1 but got rank " 1340 << permType.getRank(); 1341 if (inputType.hasRank() && permType.hasRank()) 1342 if (!permType.isDynamicDim(0) && 1343 permType.getDimSize(0) != inputType.getRank()) 1344 return emitOpError() << "expected permutation tensor dim 0 to have size " 1345 << inputType.getRank() 1346 << " (input rank) but got size " 1347 << permType.getDimSize(0); 1348 if (inputType.hasRank() && outputType.hasRank() && 1349 inputType.getRank() != outputType.getRank()) 1350 return emitOpError() 1351 << "expected input tensor rank to equal result tensor rank"; 1352 if (outputType.hasRank() && permType.hasRank()) 1353 if (!permType.isDynamicDim(0) && 1354 permType.getDimSize(0) != outputType.getRank()) 1355 return emitOpError() << "expected permutation tensor dim 0 to have size " 1356 << outputType.getRank() 1357 << " (output rank) but got size " 1358 << permType.getDimSize(0); 1359 1360 SmallVector<int32_t> constantPerms; 1361 if (succeeded(getConstantPerms(constantPerms))) { 1362 // Assert that the permutation tensor has a rank, which means that the 1363 // rank has been verified above. 1364 assert(permType.hasRank() && 1365 "Unexpectedly found permutation tensor without rank"); 1366 if (!llvm::all_of(constantPerms, 1367 [&constantPerms](int32_t s) { 1368 return s >= 0 && 1369 static_cast<size_t>(s) < constantPerms.size(); 1370 }) || 1371 !isPermutationVector(llvm::to_vector(llvm::map_range( 1372 constantPerms, [](int32_t v) -> int64_t { return v; })))) 1373 return emitOpError() << "expected valid permutation tensor"; 1374 1375 // Verify that the types of the input and output tensors are properly 1376 // permuted. 1377 if (inputType.hasRank() && outputType.hasRank()) { 1378 assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) && 1379 inputType.getRank() == outputType.getRank()); 1380 1381 for (auto i = 0; i < outputType.getRank(); i++) { 1382 if (inputType.isDynamicDim(constantPerms[i]) || 1383 outputType.isDynamicDim(i)) 1384 continue; 1385 1386 if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i)) 1387 return emitOpError() 1388 << "expected output tensor dim " << i << " to match " 1389 << "input dim " << constantPerms[i] << " with value of " 1390 << inputType.getDimSize(constantPerms[i]); 1391 } 1392 } 1393 } 1394 return success(); 1395 } 1396 1397 LogicalResult TransposeOp::reifyResultShapes( 1398 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 1399 1400 SmallVector<int32_t> transposePerms; 1401 if (getConstantPerms(transposePerms).failed()) 1402 return failure(); 1403 1404 Value input = getInput1(); 1405 auto inputType = cast<TensorType>(input.getType()); 1406 1407 SmallVector<OpFoldResult> returnedDims(inputType.getRank()); 1408 for (auto dim : transposePerms) { 1409 int32_t dimInInput = transposePerms[dim]; 1410 if (inputType.isDynamicDim(dimInInput)) 1411 returnedDims[dim] = 1412 builder.create<tensor::DimOp>(getLoc(), input, dimInInput) 1413 .getResult(); 1414 else 1415 returnedDims[dim] = 1416 builder.getIndexAttr(inputType.getDimSize(dimInInput)); 1417 } 1418 1419 reifiedReturnShapes.emplace_back(std::move(returnedDims)); 1420 return success(); 1421 } 1422 1423 LogicalResult tosa::GatherOp::inferReturnTypeComponents( 1424 MLIRContext *context, ::std::optional<Location> location, 1425 GatherOp::Adaptor adaptor, 1426 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1427 llvm::SmallVector<int64_t> outputShape; 1428 outputShape.resize(3, ShapedType::kDynamic); 1429 1430 ShapeAdaptor valuesShape(adaptor.getValues().getType()); 1431 if (valuesShape.hasRank()) { 1432 outputShape[0] = valuesShape.getDimSize(0); 1433 outputShape[2] = valuesShape.getDimSize(2); 1434 } 1435 1436 ShapeAdaptor indicesShape(adaptor.getIndices().getType()); 1437 if (indicesShape.hasRank()) { 1438 if (outputShape[0] == ShapedType::kDynamic) 1439 outputShape[0] = indicesShape.getDimSize(0); 1440 if (outputShape[1] == ShapedType::kDynamic) 1441 outputShape[1] = indicesShape.getDimSize(1); 1442 } 1443 1444 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1445 return success(); 1446 } 1447 1448 LogicalResult tosa::ResizeOp::inferReturnTypeComponents( 1449 MLIRContext *context, ::std::optional<Location> location, 1450 ResizeOp::Adaptor adaptor, 1451 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1452 llvm::SmallVector<int64_t, 4> outputShape; 1453 outputShape.resize(4, ShapedType::kDynamic); 1454 1455 ShapeAdaptor inputShape(adaptor.getInput().getType()); 1456 if (!inputShape.hasRank()) 1457 return failure(); 1458 1459 outputShape[0] = inputShape.getDimSize(0); 1460 outputShape[3] = inputShape.getDimSize(3); 1461 int64_t inputHeight = inputShape.getDimSize(1); 1462 int64_t inputWidth = inputShape.getDimSize(2); 1463 1464 if ((inputHeight == ShapedType::kDynamic) || 1465 (inputWidth == ShapedType::kDynamic)) 1466 return failure(); 1467 1468 llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale(); 1469 llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset(); 1470 llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder(); 1471 1472 // Compute the output shape based on attributes: scale, offset, and border. 1473 outputShape[1] = 1474 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) / 1475 scaleInt[1]) + 1476 1; 1477 1478 outputShape[2] = 1479 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) / 1480 scaleInt[3]) + 1481 1; 1482 1483 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1484 return success(); 1485 } 1486 1487 LogicalResult tosa::ScatterOp::inferReturnTypeComponents( 1488 MLIRContext *context, ::std::optional<Location> location, 1489 ScatterOp::Adaptor adaptor, 1490 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1491 llvm::SmallVector<int64_t> outputShape; 1492 outputShape.resize(3, ShapedType::kDynamic); 1493 1494 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType()); 1495 if (valuesInShape.hasRank()) { 1496 outputShape[0] = valuesInShape.getDimSize(0); 1497 outputShape[1] = valuesInShape.getDimSize(1); 1498 outputShape[2] = valuesInShape.getDimSize(2); 1499 } 1500 1501 ShapeAdaptor indicesShape(adaptor.getIndices().getType()); 1502 if (indicesShape.hasRank()) { 1503 if (outputShape[0] == ShapedType::kDynamic) 1504 outputShape[0] = indicesShape.getDimSize(0); 1505 } 1506 1507 ShapeAdaptor inputShape(adaptor.getInput().getType()); 1508 if (inputShape.hasRank()) { 1509 if (outputShape[0] == ShapedType::kDynamic) 1510 outputShape[0] = inputShape.getDimSize(0); 1511 if (outputShape[2] == ShapedType::kDynamic) 1512 outputShape[2] = inputShape.getDimSize(2); 1513 } 1514 1515 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1516 return success(); 1517 } 1518 1519 static LogicalResult ReduceInferReturnTypes( 1520 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, 1521 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1522 int64_t axisVal = axis.getValue().getSExtValue(); 1523 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) { 1524 inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); 1525 return success(); 1526 } 1527 1528 SmallVector<int64_t> outputShape; 1529 operandShape.getDims(outputShape); 1530 outputShape[axisVal] = 1; 1531 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); 1532 return success(); 1533 } 1534 1535 #define COMPATIBLE_RETURN_TYPES(OP) \ 1536 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \ 1537 if (l.size() != r.size() || l.size() != 1) \ 1538 return false; \ 1539 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \ 1540 return false; \ 1541 return succeeded(verifyCompatibleShape(l[0], r[0])); \ 1542 } 1543 1544 #define REDUCE_SHAPE_INFER(OP) \ 1545 LogicalResult OP::inferReturnTypeComponents( \ 1546 MLIRContext *context, ::std::optional<Location> location, \ 1547 OP::Adaptor adaptor, \ 1548 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \ 1549 Type inputType = \ 1550 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \ 1551 ShapeAdaptor inputShape(adaptor.getInput().getType()); \ 1552 const Properties &prop = adaptor.getProperties(); \ 1553 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \ 1554 inferredReturnShapes); \ 1555 } \ 1556 COMPATIBLE_RETURN_TYPES(OP) 1557 1558 REDUCE_SHAPE_INFER(tosa::ReduceAllOp) 1559 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp) 1560 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp) 1561 REDUCE_SHAPE_INFER(tosa::ReduceMinOp) 1562 REDUCE_SHAPE_INFER(tosa::ReduceProdOp) 1563 REDUCE_SHAPE_INFER(tosa::ReduceSumOp) 1564 #undef REDUCE_SHAPE_INFER 1565 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp) 1566 #undef COMPATIBLE_RETURN_TYPES 1567 1568 template <typename T> 1569 static LogicalResult verifyReduceOp(T op) { 1570 // All TOSA reduce Ops have input, output and axis. 1571 TensorType inputType = op.getInput().getType(); 1572 TensorType outputType = op.getOutput().getType(); 1573 int32_t reduceAxis = op.getAxis(); 1574 1575 if (reduceAxis < 0) { 1576 op.emitOpError("reduce axis must not be negative"); 1577 return failure(); 1578 } 1579 if (inputType.hasRank()) { 1580 int64_t inputRank = inputType.getRank(); 1581 // We allow for a special case where the input/output shape has rank 0 and 1582 // axis is also 0. 1583 if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) { 1584 op.emitOpError("expect input tensor rank (") 1585 << inputRank << ") to be larger than reduce axis (" << reduceAxis 1586 << ")"; 1587 return failure(); 1588 } 1589 } 1590 if (outputType.hasRank()) { 1591 int64_t outputRank = outputType.getRank(); 1592 if (inputType.hasRank() && outputRank != inputType.getRank()) { 1593 op.emitOpError( 1594 "expect output tensor rank to be equal to input tensor rank"); 1595 return failure(); 1596 } 1597 if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) { 1598 op.emitOpError("expect output tensor rank (") 1599 << outputRank << ") to be larger than reduce axis (" << reduceAxis 1600 << ")"; 1601 return failure(); 1602 } 1603 // We can only verify the reduced dimension size to be 1 if this is not 1604 // the special case of output rank == 0. 1605 if (outputRank != 0) { 1606 auto outputShape = outputType.getShape(); 1607 if (!outputType.isDynamicDim(reduceAxis) && 1608 outputShape[reduceAxis] != 1) { 1609 op.emitOpError("expect reduced dimension size to be 1, got ") 1610 << outputShape[reduceAxis]; 1611 return failure(); 1612 } 1613 } 1614 } 1615 return success(); 1616 } 1617 1618 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); } 1619 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); } 1620 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); } 1621 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); } 1622 LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); } 1623 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); } 1624 1625 static LogicalResult NAryInferReturnTypes( 1626 const ValueShapeRange &operands, 1627 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1628 llvm::SmallVector<int64_t> outShape; 1629 if (resolveBroadcastShape(operands, outShape).failed()) { 1630 inferredReturnShapes.push_back(ShapedTypeComponents()); 1631 } else { 1632 inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); 1633 } 1634 return success(); 1635 } 1636 1637 #define NARY_SHAPE_INFER(OP) \ 1638 LogicalResult OP::inferReturnTypeComponents( \ 1639 MLIRContext *context, ::std::optional<Location> location, \ 1640 ValueShapeRange operands, DictionaryAttr attributes, \ 1641 OpaqueProperties properties, RegionRange regions, \ 1642 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \ 1643 return NAryInferReturnTypes(operands, inferredReturnShapes); \ 1644 } 1645 1646 NARY_SHAPE_INFER(tosa::AbsOp) 1647 NARY_SHAPE_INFER(tosa::AddOp) 1648 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp) 1649 NARY_SHAPE_INFER(tosa::BitwiseAndOp) 1650 NARY_SHAPE_INFER(tosa::BitwiseOrOp) 1651 NARY_SHAPE_INFER(tosa::BitwiseXorOp) 1652 NARY_SHAPE_INFER(tosa::BitwiseNotOp) 1653 NARY_SHAPE_INFER(tosa::CastOp) 1654 NARY_SHAPE_INFER(tosa::CeilOp) 1655 NARY_SHAPE_INFER(tosa::ClampOp) 1656 NARY_SHAPE_INFER(tosa::ClzOp) 1657 NARY_SHAPE_INFER(tosa::CosOp) 1658 NARY_SHAPE_INFER(tosa::ExpOp) 1659 NARY_SHAPE_INFER(tosa::FloorOp) 1660 NARY_SHAPE_INFER(tosa::GreaterEqualOp) 1661 NARY_SHAPE_INFER(tosa::GreaterOp) 1662 NARY_SHAPE_INFER(tosa::IdentityOp) 1663 NARY_SHAPE_INFER(tosa::IntDivOp) 1664 NARY_SHAPE_INFER(tosa::LogOp) 1665 NARY_SHAPE_INFER(tosa::LogicalAndOp) 1666 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp) 1667 NARY_SHAPE_INFER(tosa::LogicalNotOp) 1668 NARY_SHAPE_INFER(tosa::LogicalOrOp) 1669 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp) 1670 NARY_SHAPE_INFER(tosa::LogicalXorOp) 1671 NARY_SHAPE_INFER(tosa::MaximumOp) 1672 NARY_SHAPE_INFER(tosa::MinimumOp) 1673 NARY_SHAPE_INFER(tosa::MulOp) 1674 NARY_SHAPE_INFER(tosa::NegateOp) 1675 NARY_SHAPE_INFER(tosa::PowOp) 1676 NARY_SHAPE_INFER(tosa::ReciprocalOp) 1677 NARY_SHAPE_INFER(tosa::RescaleOp) 1678 NARY_SHAPE_INFER(tosa::ReverseOp) 1679 NARY_SHAPE_INFER(tosa::RsqrtOp) 1680 NARY_SHAPE_INFER(tosa::SinOp) 1681 NARY_SHAPE_INFER(tosa::SelectOp) 1682 NARY_SHAPE_INFER(tosa::SubOp) 1683 NARY_SHAPE_INFER(tosa::TanhOp) 1684 NARY_SHAPE_INFER(tosa::ErfOp) 1685 NARY_SHAPE_INFER(tosa::SigmoidOp) 1686 #undef PRED_SHAPE_INFER 1687 1688 static LogicalResult poolingInferReturnTypes( 1689 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, 1690 ArrayRef<int64_t> pad, 1691 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1692 llvm::SmallVector<int64_t> outputShape; 1693 outputShape.resize(4, ShapedType::kDynamic); 1694 1695 // We only know the rank if the input type is unranked. 1696 if (!inputShape) { 1697 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1698 return success(); 1699 } 1700 1701 // Batch and number of channels are identical for pooling layer. 1702 outputShape[0] = inputShape.getDimSize(0); 1703 outputShape[3] = inputShape.getDimSize(3); 1704 1705 int64_t height = inputShape.getDimSize(1); 1706 int64_t width = inputShape.getDimSize(2); 1707 1708 if (!ShapedType::isDynamic(height)) { 1709 int64_t padded = height + pad[0] + pad[1] - kernel[0]; 1710 outputShape[1] = padded / stride[0] + 1; 1711 } 1712 1713 if (!ShapedType::isDynamic(width)) { 1714 int64_t padded = width + pad[2] + pad[3] - kernel[1]; 1715 outputShape[2] = padded / stride[1] + 1; 1716 } 1717 1718 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1719 return success(); 1720 } 1721 1722 LogicalResult Conv2DOp::inferReturnTypeComponents( 1723 MLIRContext *context, ::std::optional<Location> location, 1724 Conv2DOp::Adaptor adaptor, 1725 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1726 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic); 1727 1728 int64_t inputWidth = ShapedType::kDynamic; 1729 int64_t inputHeight = ShapedType::kDynamic; 1730 int64_t weightWidth = ShapedType::kDynamic; 1731 int64_t weightHeight = ShapedType::kDynamic; 1732 1733 // Input shape describes input width/height and batch. 1734 1735 ShapeAdaptor inputShape(adaptor.getInput().getType()); 1736 if (inputShape.hasRank()) { 1737 outputShape[0] = inputShape.getDimSize(0); 1738 inputHeight = inputShape.getDimSize(1); 1739 inputWidth = inputShape.getDimSize(2); 1740 } 1741 1742 // Weight shapes describes the filter width/height and the output channels. 1743 ShapeAdaptor weightShape(adaptor.getWeight().getType()); 1744 if (weightShape.hasRank()) { 1745 outputShape[3] = weightShape.getDimSize(0); 1746 weightHeight = weightShape.getDimSize(1); 1747 weightWidth = weightShape.getDimSize(2); 1748 } 1749 1750 // Bias shape can describe the output channels. 1751 ShapeAdaptor biasShape(adaptor.getBias().getType()); 1752 if (biasShape.hasRank()) { 1753 outputShape[3] = ShapedType::isDynamic(outputShape[3]) 1754 ? biasShape.getDimSize(0) 1755 : outputShape[3]; 1756 } 1757 1758 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); 1759 llvm::ArrayRef<int64_t> stride = adaptor.getStride(); 1760 llvm::ArrayRef<int64_t> padding = adaptor.getPad(); 1761 1762 if (!ShapedType::isDynamic(inputHeight) && 1763 !ShapedType::isDynamic(weightHeight)) { 1764 int64_t inputSize = inputHeight + padding[0] + padding[1]; 1765 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; 1766 int64_t unstridedResult = inputSize - filterSize + 1; 1767 outputShape[1] = (unstridedResult - 1) / stride[0] + 1; 1768 } 1769 1770 if (!ShapedType::isDynamic(inputWidth) && 1771 !ShapedType::isDynamic(weightWidth)) { 1772 int64_t inputSize = inputWidth + padding[2] + padding[3]; 1773 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; 1774 int64_t unstridedResult = inputSize - filterSize + 1; 1775 outputShape[2] = (unstridedResult - 1) / stride[1] + 1; 1776 } 1777 1778 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1779 return success(); 1780 } 1781 1782 LogicalResult Conv2DOp::verify() { 1783 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) 1784 return failure(); 1785 return success(); 1786 } 1787 1788 LogicalResult Conv3DOp::inferReturnTypeComponents( 1789 MLIRContext *context, ::std::optional<Location> location, 1790 Conv3DOp::Adaptor adaptor, 1791 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1792 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic); 1793 1794 int64_t inputWidth = ShapedType::kDynamic; 1795 int64_t inputHeight = ShapedType::kDynamic; 1796 int64_t inputDepth = ShapedType::kDynamic; 1797 1798 int64_t weightWidth = ShapedType::kDynamic; 1799 int64_t weightHeight = ShapedType::kDynamic; 1800 int64_t weightDepth = ShapedType::kDynamic; 1801 1802 // Input shape describes input width/height and batch. 1803 ShapeAdaptor inputShape(adaptor.getInput().getType()); 1804 if (inputShape.hasRank()) { 1805 outputShape[0] = inputShape.getDimSize(0); 1806 inputDepth = inputShape.getDimSize(1); 1807 inputHeight = inputShape.getDimSize(2); 1808 inputWidth = inputShape.getDimSize(3); 1809 } 1810 1811 // Weight shapes describes the filter width/height and the output channels. 1812 ShapeAdaptor weightShape(adaptor.getWeight().getType()); 1813 if (weightShape.hasRank()) { 1814 outputShape[4] = weightShape.getDimSize(0); 1815 weightDepth = weightShape.getDimSize(1); 1816 weightHeight = weightShape.getDimSize(2); 1817 weightWidth = weightShape.getDimSize(3); 1818 } 1819 1820 // Bias shape can describe the output channels. 1821 ShapeAdaptor biasShape(adaptor.getBias().getType()); 1822 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) { 1823 outputShape[4] = biasShape.getDimSize(0); 1824 } 1825 1826 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); 1827 llvm::ArrayRef<int64_t> stride = adaptor.getStride(); 1828 llvm::ArrayRef<int64_t> pad = adaptor.getPad(); 1829 1830 if (!ShapedType::isDynamic(inputDepth) && 1831 !ShapedType::isDynamic(weightDepth)) { 1832 int32_t inputSize = inputDepth + pad[0] + pad[1]; 1833 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1; 1834 int32_t unstridedResult = inputSize - filterSize + 1; 1835 outputShape[1] = (unstridedResult - 1) / stride[0] + 1; 1836 } 1837 1838 if (!ShapedType::isDynamic(inputHeight) && 1839 !ShapedType::isDynamic(weightHeight)) { 1840 int32_t inputSize = inputHeight + pad[2] + pad[3]; 1841 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1; 1842 int32_t unstridedResult = inputSize - filterSize + 1; 1843 outputShape[2] = (unstridedResult - 1) / stride[1] + 1; 1844 } 1845 1846 if (!ShapedType::isDynamic(inputWidth) && 1847 !ShapedType::isDynamic(weightWidth)) { 1848 int32_t inputSize = inputWidth + pad[4] + pad[5]; 1849 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1; 1850 int32_t unstridedResult = inputSize - filterSize + 1; 1851 outputShape[3] = (unstridedResult - 1) / stride[2] + 1; 1852 } 1853 1854 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1855 return success(); 1856 } 1857 1858 LogicalResult Conv3DOp::verify() { 1859 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) 1860 return failure(); 1861 return success(); 1862 } 1863 1864 LogicalResult AvgPool2dOp::inferReturnTypeComponents( 1865 MLIRContext *context, ::std::optional<Location> location, 1866 AvgPool2dOp::Adaptor adaptor, 1867 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1868 ShapeAdaptor inputShape(adaptor.getInput().getType()); 1869 const Properties &prop = adaptor.getProperties(); 1870 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, 1871 inferredReturnShapes); 1872 } 1873 1874 LogicalResult MaxPool2dOp::inferReturnTypeComponents( 1875 MLIRContext *context, ::std::optional<Location> location, 1876 MaxPool2dOp::Adaptor adaptor, 1877 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1878 ShapeAdaptor inputShape(adaptor.getInput().getType()); 1879 const Properties &prop = adaptor.getProperties(); 1880 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, 1881 inferredReturnShapes); 1882 } 1883 1884 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( 1885 MLIRContext *context, ::std::optional<Location> location, 1886 DepthwiseConv2DOp::Adaptor adaptor, 1887 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1888 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic); 1889 1890 int64_t inputWidth = ShapedType::kDynamic; 1891 int64_t inputHeight = ShapedType::kDynamic; 1892 int64_t inputChannels = ShapedType::kDynamic; 1893 1894 int64_t weightWidth = ShapedType::kDynamic; 1895 int64_t weightHeight = ShapedType::kDynamic; 1896 int64_t depthChannels = ShapedType::kDynamic; 1897 1898 // Input shape describes input width/height and batch. 1899 ShapeAdaptor inputShape(adaptor.getInput().getType()); 1900 if (inputShape.hasRank()) { 1901 outputShape[0] = inputShape.getDimSize(0); 1902 inputHeight = inputShape.getDimSize(1); 1903 inputWidth = inputShape.getDimSize(2); 1904 inputChannels = inputShape.getDimSize(3); 1905 } 1906 1907 // Weight shapes describes the filter width/height and the output channels. 1908 ShapeAdaptor weightShape(adaptor.getWeight().getType()); 1909 if (weightShape.hasRank()) { 1910 weightHeight = weightShape.getDimSize(0); 1911 weightWidth = weightShape.getDimSize(1); 1912 inputChannels = ShapedType::isDynamic(inputChannels) 1913 ? weightShape.getDimSize(2) 1914 : inputChannels; 1915 depthChannels = weightShape.getDimSize(3); 1916 } 1917 1918 // If both inputChannels and depthChannels are available we can determine 1919 // the output channels. 1920 if (!ShapedType::isDynamic(inputChannels) && 1921 !ShapedType::isDynamic(depthChannels)) { 1922 outputShape[3] = inputChannels * depthChannels; 1923 } 1924 1925 // Bias shape can describe the output channels. 1926 ShapeAdaptor biasShape(adaptor.getBias().getType()); 1927 if (biasShape.hasRank()) { 1928 outputShape[3] = ShapedType::isDynamic(outputShape[3]) 1929 ? biasShape.getDimSize(0) 1930 : outputShape[3]; 1931 } 1932 1933 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); 1934 llvm::ArrayRef<int64_t> padding = adaptor.getPad(); 1935 llvm::ArrayRef<int64_t> stride = adaptor.getStride(); 1936 1937 if (!ShapedType::isDynamic(inputHeight) && 1938 !ShapedType::isDynamic(weightHeight)) { 1939 int64_t inputSize = inputHeight + padding[0] + padding[1]; 1940 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; 1941 int64_t unstridedResult = inputSize - filterSize + 1; 1942 outputShape[1] = (unstridedResult - 1) / stride[0] + 1; 1943 } 1944 1945 if (!ShapedType::isDynamic(inputWidth) && 1946 !ShapedType::isDynamic(weightWidth)) { 1947 int64_t inputSize = inputWidth + padding[2] + padding[3]; 1948 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; 1949 int64_t unstridedResult = inputSize - filterSize + 1; 1950 outputShape[2] = (unstridedResult - 1) / stride[1] + 1; 1951 } 1952 1953 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1954 return success(); 1955 } 1956 1957 LogicalResult DepthwiseConv2DOp::verify() { 1958 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) 1959 return failure(); 1960 return success(); 1961 } 1962 1963 LogicalResult TransposeConv2DOp::inferReturnTypeComponents( 1964 MLIRContext *context, ::std::optional<Location> location, 1965 TransposeConv2DOp::Adaptor adaptor, 1966 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1967 // outputShape is mutable. 1968 llvm::SmallVector<int64_t> outputShape = 1969 convertToMlirShape(adaptor.getOutShape()); 1970 1971 int64_t inputWidth = ShapedType::kDynamic; 1972 int64_t inputHeight = ShapedType::kDynamic; 1973 int64_t weightWidth = ShapedType::kDynamic; 1974 int64_t weightHeight = ShapedType::kDynamic; 1975 1976 // Input shape describes input width/height and batch. 1977 ShapeAdaptor inputShape(adaptor.getInput().getType()); 1978 if (inputShape.hasRank()) { 1979 outputShape[0] = ShapedType::isDynamic(outputShape[0]) 1980 ? inputShape.getDimSize(0) 1981 : outputShape[0]; 1982 inputHeight = inputShape.getDimSize(1); 1983 inputWidth = inputShape.getDimSize(2); 1984 } 1985 1986 // Weight shapes describes the filter width/height and the output channels. 1987 ShapeAdaptor weightShape(adaptor.getFilter().getType()); 1988 if (weightShape.hasRank()) { 1989 outputShape[3] = ShapedType::isDynamic(outputShape[3]) 1990 ? weightShape.getDimSize(0) 1991 : outputShape[3]; 1992 weightHeight = weightShape.getDimSize(1); 1993 weightWidth = weightShape.getDimSize(2); 1994 } 1995 1996 // Bias shape can describe the output channels. 1997 ShapeAdaptor biasShape(adaptor.getInput().getType()); 1998 if (biasShape.hasRank()) { 1999 outputShape[3] = ShapedType::isDynamic(outputShape[3]) 2000 ? biasShape.getDimSize(0) 2001 : outputShape[3]; 2002 } 2003 2004 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad(); 2005 llvm::ArrayRef<int64_t> stride = adaptor.getStride(); 2006 2007 if (!ShapedType::isDynamic(inputHeight) && 2008 !ShapedType::isDynamic(weightHeight)) { 2009 int64_t calculateSize = 2010 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight; 2011 outputShape[1] = 2012 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1]; 2013 } 2014 2015 if (!ShapedType::isDynamic(inputWidth) && 2016 !ShapedType::isDynamic(weightWidth)) { 2017 int64_t calculateSize = 2018 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth; 2019 outputShape[2] = 2020 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2]; 2021 } 2022 2023 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 2024 return success(); 2025 } 2026 2027 LogicalResult TransposeConv2DOp::verify() { 2028 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) 2029 return failure(); 2030 return success(); 2031 } 2032 2033 LogicalResult IfOp::inferReturnTypeComponents( 2034 MLIRContext *context, ::std::optional<Location> location, 2035 IfOp::Adaptor adaptor, 2036 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 2037 llvm::SmallVector<tosa::YieldOp> yieldOps; 2038 for (Region *region : adaptor.getRegions()) { 2039 for (auto &block : *region) 2040 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator())) 2041 yieldOps.push_back(returnOp); 2042 } 2043 2044 if (yieldOps.empty()) 2045 return failure(); 2046 2047 // Get the initial type information for the yield op. 2048 llvm::SmallVector<ValueKnowledge> resultKnowledge; 2049 resultKnowledge.reserve(yieldOps.front().getNumOperands()); 2050 for (auto operand : yieldOps.front().getOperands()) { 2051 resultKnowledge.push_back( 2052 ValueKnowledge::getKnowledgeFromType(operand.getType())); 2053 } 2054 2055 for (auto yieldOp : yieldOps) { 2056 if (resultKnowledge.size() != yieldOp.getNumOperands()) 2057 return failure(); 2058 2059 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 2060 int32_t index = it.index(); 2061 auto meet = ValueKnowledge::meet( 2062 resultKnowledge[index], 2063 ValueKnowledge::getKnowledgeFromType(it.value().getType())); 2064 if (!meet) 2065 continue; 2066 resultKnowledge[index] = meet; 2067 } 2068 } 2069 2070 for (const ValueKnowledge &result : resultKnowledge) { 2071 inferredReturnShapes.push_back(result.getShapedTypeComponents()); 2072 } 2073 2074 return success(); 2075 } 2076 2077 LogicalResult WhileOp::inferReturnTypeComponents( 2078 MLIRContext *context, ::std::optional<Location> location, 2079 WhileOp::Adaptor adaptor, 2080 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 2081 llvm::SmallVector<tosa::YieldOp> yieldOps; 2082 for (auto &block : adaptor.getBody()) 2083 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator())) 2084 yieldOps.push_back(returnOp); 2085 2086 // TOSA's while must have a tosa.yield as its terminator. If not found this 2087 // tosa.while is invalid. 2088 if (yieldOps.empty()) 2089 return failure(); 2090 2091 // Get the initial type information from the operand types. 2092 llvm::SmallVector<ValueKnowledge> resultKnowledge; 2093 resultKnowledge.reserve(yieldOps.front().getNumOperands()); 2094 for (auto operand : yieldOps.front().getOperands()) { 2095 resultKnowledge.push_back( 2096 ValueKnowledge::getKnowledgeFromType(operand.getType())); 2097 } 2098 2099 for (auto yieldOp : yieldOps) { 2100 if (resultKnowledge.size() != yieldOp.getNumOperands()) 2101 return failure(); 2102 2103 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 2104 int32_t index = it.index(); 2105 if (auto meet = ValueKnowledge::meet( 2106 resultKnowledge[index], 2107 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) { 2108 resultKnowledge[index] = meet; 2109 } 2110 } 2111 } 2112 2113 for (const ValueKnowledge &result : resultKnowledge) { 2114 inferredReturnShapes.push_back(result.getShapedTypeComponents()); 2115 } 2116 2117 return success(); 2118 } 2119 2120 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { 2121 if (auto vt = llvm::dyn_cast<VectorType>(getType())) 2122 return llvm::to_vector<4>(vt.getShape()); 2123 return std::nullopt; 2124 } 2125 2126 // parse and print of IfOp refer to the implementation of SCF dialect. 2127 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { 2128 // Create the regions for 'then'. 2129 result.regions.reserve(2); 2130 Region *thenRegion = result.addRegion(); 2131 Region *elseRegion = result.addRegion(); 2132 2133 auto &builder = parser.getBuilder(); 2134 OpAsmParser::UnresolvedOperand cond; 2135 // Create a i1 tensor type for the boolean condition. 2136 Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); 2137 if (parser.parseOperand(cond) || 2138 parser.resolveOperand(cond, i1Type, result.operands)) 2139 return failure(); 2140 // Parse optional results type list. 2141 if (parser.parseOptionalArrowTypeList(result.types)) 2142 return failure(); 2143 // Parse the 'then' region. 2144 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) 2145 return failure(); 2146 2147 // If we find an 'else' keyword then parse the 'else' region. 2148 if (!parser.parseOptionalKeyword("else")) { 2149 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) 2150 return failure(); 2151 } 2152 2153 // Parse the optional attribute list. 2154 if (parser.parseOptionalAttrDict(result.attributes)) 2155 return failure(); 2156 return success(); 2157 } 2158 2159 void IfOp::print(OpAsmPrinter &p) { 2160 bool printBlockTerminators = false; 2161 2162 p << " " << getCond(); 2163 if (!getResults().empty()) { 2164 p << " -> (" << getResultTypes() << ")"; 2165 // Print yield explicitly if the op defines values. 2166 printBlockTerminators = true; 2167 } 2168 p << ' '; 2169 p.printRegion(getThenBranch(), 2170 /*printEntryBlockArgs=*/false, 2171 /*printBlockTerminators=*/printBlockTerminators); 2172 2173 // Print the 'else' regions if it exists and has a block. 2174 auto &elseRegion = getElseBranch(); 2175 if (!elseRegion.empty()) { 2176 p << " else "; 2177 p.printRegion(elseRegion, 2178 /*printEntryBlockArgs=*/false, 2179 /*printBlockTerminators=*/printBlockTerminators); 2180 } 2181 2182 p.printOptionalAttrDict((*this)->getAttrs()); 2183 } 2184 2185 LogicalResult ReverseOp::verify() { 2186 TensorType inputType = getInput1().getType(); 2187 TensorType outputType = getOutput().getType(); 2188 int32_t reverseAxis = getAxis(); 2189 2190 if (reverseAxis < 0) 2191 return emitOpError("expected non-negative reverse axis"); 2192 if (inputType.hasRank()) { 2193 int64_t inputRank = inputType.getRank(); 2194 // We allow for a special case where the input/output shape has rank 0 and 2195 // axis is also 0. 2196 if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0)) 2197 return emitOpError("expect input tensor rank (") 2198 << inputRank << ") to be larger than reverse axis (" << reverseAxis 2199 << ")"; 2200 } 2201 if (outputType.hasRank()) { 2202 int64_t outputRank = outputType.getRank(); 2203 if (inputType.hasRank() && outputRank != inputType.getRank()) 2204 return emitOpError( 2205 "expect output tensor rank to be equal to input tensor rank"); 2206 if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0)) 2207 return emitOpError("expect output tensor rank (") 2208 << outputRank << ") to be larger than reverse axis (" 2209 << reverseAxis << ")"; 2210 } 2211 return success(); 2212 } 2213 2214 // parse and print of WhileOp refer to the implementation of SCF dialect. 2215 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { 2216 SmallVector<OpAsmParser::Argument, 4> regionArgs; 2217 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; 2218 Region *cond = result.addRegion(); 2219 Region *body = result.addRegion(); 2220 2221 OptionalParseResult listResult = 2222 parser.parseOptionalAssignmentList(regionArgs, operands); 2223 if (listResult.has_value() && failed(listResult.value())) 2224 return failure(); 2225 2226 FunctionType functionType; 2227 SMLoc typeLoc = parser.getCurrentLocation(); 2228 if (failed(parser.parseColonType(functionType))) 2229 return failure(); 2230 2231 result.addTypes(functionType.getResults()); 2232 2233 if (functionType.getNumInputs() != operands.size()) { 2234 return parser.emitError(typeLoc) 2235 << "expected as many input types as operands " 2236 << "(expected " << operands.size() << " got " 2237 << functionType.getNumInputs() << ")"; 2238 } 2239 2240 // Resolve input operands. 2241 if (failed(parser.resolveOperands(operands, functionType.getInputs(), 2242 parser.getCurrentLocation(), 2243 result.operands))) 2244 return failure(); 2245 2246 // Propagate the types into the region arguments. 2247 for (size_t i = 0, e = regionArgs.size(); i != e; ++i) 2248 regionArgs[i].type = functionType.getInput(i); 2249 2250 return failure(parser.parseRegion(*cond, regionArgs) || 2251 parser.parseKeyword("do") || parser.parseRegion(*body) || 2252 parser.parseOptionalAttrDictWithKeyword(result.attributes)); 2253 } 2254 2255 static void printInitializationList(OpAsmPrinter &parser, 2256 Block::BlockArgListType blocksArgs, 2257 ValueRange initializers, 2258 StringRef prefix = "") { 2259 assert(blocksArgs.size() == initializers.size() && 2260 "expected same length of arguments and initializers"); 2261 if (initializers.empty()) 2262 return; 2263 2264 parser << prefix << '('; 2265 llvm::interleaveComma( 2266 llvm::zip(blocksArgs, initializers), parser, 2267 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); 2268 parser << ")"; 2269 } 2270 2271 void WhileOp::print(OpAsmPrinter &parser) { 2272 printInitializationList(parser, getCond().front().getArguments(), getInputs(), 2273 " "); 2274 parser << " : "; 2275 parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes()); 2276 parser << ' '; 2277 parser.printRegion(getCond(), /*printEntryBlockArgs=*/false); 2278 parser << " do "; 2279 parser.printRegion(getBody()); 2280 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 2281 } 2282 2283 //===----------------------------------------------------------------------===// 2284 // TOSA Shape and Shape Operators Helper functions. 2285 //===----------------------------------------------------------------------===// 2286 2287 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) { 2288 return mlir::isa<tosa::shapeType>(t); 2289 } 2290 2291 LogicalResult 2292 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError, 2293 int rank) { 2294 if (rank < 0) 2295 return emitError() << "invalid rank (must be >= 0): " << rank; 2296 return success(); 2297 } 2298 2299 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) { 2300 for (auto v : op->getOperands()) { 2301 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) { 2302 Operation *definingOp = v.getDefiningOp(); 2303 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) { 2304 return op->emitOpError("shape operand is not compile time resolvable"); 2305 } 2306 } 2307 } 2308 return success(); 2309 } 2310 2311 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) { 2312 for (auto type : op->getOperandTypes()) { 2313 if (!mlir::isa<mlir::tosa::shapeType>(type)) { 2314 return op->emitOpError("must have operands with tosa shape type"); 2315 } 2316 } 2317 for (auto type : op->getResultTypes()) { 2318 if (!mlir::isa<mlir::tosa::shapeType>(type)) { 2319 return op->emitOpError("must have result with tosa shape type"); 2320 } 2321 } 2322 return success(); 2323 } 2324 2325 LogicalResult 2326 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) { 2327 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) || 2328 failed(verifyTosaShapeOperator(op))) 2329 return failure(); 2330 2331 // delegate function that returns rank of shape type 2332 auto getRank = [](const Type type) { 2333 return mlir::cast<mlir::tosa::shapeType>(type).getRank(); 2334 }; 2335 auto operandTypes = op->getOperandTypes(); 2336 auto resultTypes = op->getResultTypes(); 2337 2338 auto rank = getRank(*op->getOperandTypes().begin()); 2339 for (auto type : operandTypes) { 2340 if (getRank(type) != rank) { 2341 return op->emitOpError("operands don't have matching ranks"); 2342 } 2343 } 2344 for (auto type : resultTypes) { 2345 if (getRank(type) != rank) { 2346 return op->emitOpError("result shape has different rank than operands"); 2347 } 2348 } 2349 return success(); 2350 } 2351 2352 //===----------------------------------------------------------------------===// 2353 // TOSA Shape Operators verify functions. 2354 //===----------------------------------------------------------------------===// 2355 2356 LogicalResult tosa::ConstShapeOp::verify() { 2357 // check that number of elements in value attr equal to rank of result shape 2358 auto count = getValue().getNumElements(); 2359 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank(); 2360 if (!(count == rank || (count == 1 && rank == 0))) { 2361 return emitOpError("expect number of elements in attribute value (") 2362 << count << ") to be equal to the rank (" << rank 2363 << ") for the result shape type"; 2364 } 2365 return success(); 2366 } 2367 2368 //===----------------------------------------------------------------------===// 2369 // TOSA Attribute Definitions. 2370 //===----------------------------------------------------------------------===// 2371 2372 #define GET_ATTRDEF_CLASSES 2373 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" 2374 2375 //===----------------------------------------------------------------------===// 2376 // TOSA Type Definitions. 2377 //===----------------------------------------------------------------------===// 2378 #define GET_TYPEDEF_CLASSES 2379 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc" 2380 2381 //===----------------------------------------------------------------------===// 2382 // TOSA Operator Definitions. 2383 //===----------------------------------------------------------------------===// 2384 2385 #define GET_OP_CLASSES 2386 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" 2387