1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
17 #include "mlir/Dialect/Quant/IR/Quant.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
20 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/Interfaces/InferTypeOpInterface.h"
28 #include "mlir/Transforms/InliningUtils.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
39 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
40 
41 //===----------------------------------------------------------------------===//
42 // Tosa dialect interface includes.
43 //===----------------------------------------------------------------------===//
44 
45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46 
47 namespace {
48 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
49 
50 //===----------------------------------------------------------------------===//
51 // Dialect Function Inliner Interface.
52 //===----------------------------------------------------------------------===//
53 struct TosaInlinerInterface : public DialectInlinerInterface {
54   using DialectInlinerInterface::DialectInlinerInterface;
55 
56   //===--------------------------------------------------------------------===//
57   // Analysis Hooks.
58   //===--------------------------------------------------------------------===//
59 
60   /// All operations can be inlined by default.
61   bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
62                        IRMapping &map) const final {
63     return true;
64   }
65 
66   /// All regions with If and While parent operators can be inlined.
67   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68                        IRMapping &map) const final {
69     return (isa<tosa::IfOp>(dest->getParentOp()) ||
70             isa<tosa::WhileOp>(dest->getParentOp()));
71   }
72 };
73 
74 /// This class implements the bytecode interface for the Tosa dialect.
75 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
76   TosaDialectBytecodeInterface(Dialect *dialect)
77       : BytecodeDialectInterface(dialect) {}
78 
79   //===--------------------------------------------------------------------===//
80   // Attributes
81 
82   Attribute readAttribute(DialectBytecodeReader &reader) const override {
83     return ::readAttribute(getContext(), reader);
84   }
85 
86   LogicalResult writeAttribute(Attribute attr,
87                                DialectBytecodeWriter &writer) const override {
88     return ::writeAttribute(attr, writer);
89   }
90 
91   //===--------------------------------------------------------------------===//
92   // Types
93 
94   Type readType(DialectBytecodeReader &reader) const override {
95     return ::readType(getContext(), reader);
96   }
97 
98   LogicalResult writeType(Type type,
99                           DialectBytecodeWriter &writer) const override {
100     return ::writeType(type, writer);
101   }
102 
103   void writeVersion(DialectBytecodeWriter &writer) const final {
104     // TODO: Populate.
105   }
106 
107   std::unique_ptr<DialectVersion>
108   readVersion(DialectBytecodeReader &reader) const final {
109     // TODO: Populate
110     reader.emitError("Dialect does not support versioning");
111     return nullptr;
112   }
113 
114   LogicalResult upgradeFromVersion(Operation *topLevelOp,
115                                    const DialectVersion &version) const final {
116     return success();
117   }
118 };
119 
120 } // namespace
121 
122 //===----------------------------------------------------------------------===//
123 // TOSA control flow support.
124 //===----------------------------------------------------------------------===//
125 
126 /// Returns the while loop body.
127 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
128 
129 //===----------------------------------------------------------------------===//
130 // Tosa dialect initialization.
131 //===----------------------------------------------------------------------===//
132 
133 void TosaDialect::initialize() {
134   addTypes<
135 #define GET_TYPEDEF_LIST
136 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
137       >();
138   addOperations<
139 #define GET_OP_LIST
140 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
141       >();
142   addAttributes<
143 #define GET_ATTRDEF_LIST
144 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
145       >();
146   addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
147   declarePromisedInterfaces<
148       mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
149       ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
150       LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
151       LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
152       BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
153       NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
154       GreaterEqualOp, MatMulOp>();
155 }
156 
157 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
158                                             Type type, Location loc) {
159   // Tosa dialect constants only support ElementsAttr, unlike the standard
160   // dialect constant which supports all attributes.
161   if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
162     return builder.create<tosa::ConstShapeOp>(
163         loc, type, llvm::cast<DenseIntElementsAttr>(value));
164   }
165   if (llvm::isa<ElementsAttr>(value))
166     return builder.create<tosa::ConstOp>(loc, type,
167                                          llvm::cast<ElementsAttr>(value));
168   return nullptr;
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Parsers and printers
173 //===----------------------------------------------------------------------===//
174 
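/// Parses either "= <attr>" (taking the type from the attribute when it is
/// typed) or ": <type>" when no attribute value is present.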
175 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
176                                         Attribute &attr) {
177   if (succeeded(parser.parseOptionalEqual())) {
178     if (failed(parser.parseAttribute(attr))) {
179       return parser.emitError(parser.getCurrentLocation())
180              << "expected attribute";
181     }
182     if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
183       typeAttr = TypeAttr::get(typedAttr.getType());
184     }
185     return success();
186   }
187 
188   Type type;
189   if (failed(parser.parseColonType(type))) {
190     return parser.emitError(parser.getCurrentLocation()) << "expected type";
191   }
192   typeAttr = TypeAttr::get(type);
193 
194   return success();
195 }
196 
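/// Prints the inverse of parseTypeOrAttr: ": <type>" when the attribute is
/// absent or its type differs from `type`, followed by "= <attr>" when an
/// attribute value is present.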
197 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
198                                  Attribute attr) {
199   bool needsSpace = false;
200   auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
201   if (!typedAttr || typedAttr.getType() != type.getValue()) {
202     p << ": ";
203     p.printAttribute(type);
204     needsSpace = true; // subsequent attr value needs a space separator
205   }
206   if (attr) {
207     if (needsSpace)
208       p << ' ';
209     p << "= ";
210     p.printAttribute(attr);
211   }
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // TOSA Operator Verifiers.
216 //===----------------------------------------------------------------------===//
217 
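/// Verifies a TOSA convolution-style op: the input and weight/filter must be
/// ranked tensors, both element types must be float or both quantized, and a
/// quantization_info attribute must be present exactly when the operands are
/// quantized.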
218 template <typename T>
219 static LogicalResult verifyConvOp(T op) {
220   // All TOSA conv ops have an input() and weight().
221   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
222 
223   RankedTensorType weightType;
224   if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
225     weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
226   else
227     weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
228 
229   // Must be ranked tensor types
230   if (!inputType) {
231     op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
232     return failure();
233   }
234   if (!weightType) {
235     if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
236       op.emitOpError("expect a ranked tensor for filter, got ")
237           << op.getFilter();
238     } else {
239       op.emitOpError("expect a ranked tensor for weight, got ")
240           << op.getWeight();
241     }
242     return failure();
243   }
244 
245   auto inputEType = inputType.getElementType();
246   auto weightEType = weightType.getElementType();
247 
248   bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
249   bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
250 
251   // Either both must be quantized or both unquantized.
252   if (inputIsQuant != weightIsQuant) {
253     op.emitOpError(
254         "expect both input and weight to be float or not together, got ")
255         << inputEType << " and " << weightEType;
256     return failure();
257   }
258 
259   // Quantized types must have a constructed quantization_info attribute, and
260   // unquantized types must not have one.
261   if ((inputIsQuant && !op.getQuantizationInfo()) ||
262       (!inputIsQuant && op.getQuantizationInfo())) {
263     op.emitOpError("quantizationattr is required for quantized type, and not "
264                    "allowed for float type");
265     return failure();
266   }
267   return success();
268 }
269 
270 LogicalResult tosa::ConstOp::verify() {
271 
272   auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
273   auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
274 
275   if (!attrType || !outputType) {
276     emitOpError("expected tensors for attr/result type");
277     return failure();
278   }
279 
280   if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
281           outputType.getElementType())) {
282     if (result.getStorageType() == attrType.getElementType())
283       return success();
284   }
285 
286   if (attrType.getElementType() != outputType.getElementType()) {
287     emitOpError("expected same attr/result element types");
288     return failure();
289   }
290 
291   return success();
292 }
293 
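/// Verifies that the acc_type attribute is compatible with the input element
/// type, e.g. an i8 input requires an i32 accumulator and an i16 input
/// requires an i48 accumulator.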
294 template <typename T>
295 static LogicalResult verifyConvOpModes(T op) {
296   auto inputEType =
297       llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
298 
299   if (auto quantType =
300           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
301     inputEType = quantType.getStorageType();
302 
303   auto accType = op.getAccType();
304   if (inputEType.isInteger(8) && !accType.isInteger(32))
305     return op.emitOpError("accumulator type for i8 tensor is not i32");
306 
307   if (inputEType.isInteger(16) && !accType.isInteger(48))
308     return op.emitOpError("accumulator type for i16 tensor is not i48");
309 
310   if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
311     return op.emitOpError("accumulator type for f8 tensor is not f16");
312 
313   if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
314     return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
315 
316   if (inputEType.isBF16() && !accType.isF32())
317     return op.emitOpError("accumulator type for bf16 tensor is not f32");
318 
319   if (inputEType.isF32() && !accType.isF32())
320     return op.emitOpError("accumulator type for f32 tensor is not f32");
321 
322   return success();
323 }
324 
325 LogicalResult tosa::ArgMaxOp::verify() {
326   // Ensure the output is of integer type
327   const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
328   if (!resultETy.isIntOrIndex())
329     return emitOpError("result tensor is not of integer type");
330 
331   // Ensure axis is within the tensor rank
332   const auto inputType = llvm::cast<ShapedType>(getInput().getType());
333   const int64_t axis = getAxisAttr().getInt();
334   if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
335     return emitOpError("specified axis is outside the rank of the tensor");
336 
337   return success();
338 }
339 
340 LogicalResult tosa::AvgPool2dOp::verify() {
341   auto inputType = llvm::cast<ShapedType>(getInput().getType());
342 
343   auto inputETy = inputType.getElementType();
344   auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
345 
346   if (auto quantType =
347           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
348     inputETy = quantType.getStorageType();
349 
350   if (auto quantType =
351           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
352     resultETy = quantType.getStorageType();
353 
354   auto accType = getAccType();
355   if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
356     return emitOpError("accumulator type for integer tensor is not i32");
357 
358   if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
359     return emitOpError("accumulator type for f16 tensor is not f16/f32");
360 
361   if (inputETy.isBF16() && !accType.isF32())
362     return emitOpError("accumulator type for bf16 tensor is not f32");
363 
364   if (inputETy.isF32() && !accType.isF32())
365     return emitOpError("accumulator type for f32 tensor is not f32");
366 
367   if ((inputETy.isF32() && resultETy.isF32()) ||
368       (inputETy.isF16() && resultETy.isF16()) ||
369       (inputETy.isBF16() && resultETy.isBF16()) ||
370       (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
371       (inputETy.isInteger(16) && resultETy.isInteger(16)))
372     return success();
373 
374   return emitOpError("input/output element types are incompatible.");
375 }
376 
377 LogicalResult tosa::ClampOp::verify() {
378   mlir::Type inputETy =
379       llvm::cast<ShapedType>(getInput().getType()).getElementType();
380   if (auto quantType =
381           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
382     inputETy = quantType.getStorageType();
383   }
384   mlir::Type maxFpType = getMaxFpAttr().getType();
385   mlir::Type minFpType = getMinFpAttr().getType();
386   mlir::Type outputETy =
387       llvm::cast<ShapedType>(getOutput().getType()).getElementType();
388   if (auto quantType =
389           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
390     outputETy = quantType.getStorageType();
391   }
392   unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
393 
394   if (inputETy != outputETy)
395     return emitOpError("input/output element types are incompatible.");
396 
397   // If input datatype is float, check that the two min/max_fp attributes
398   // share the same type and that their type is either the same as the input's
399   // datatype, or a float type whose bitwidth > input datatype bitwidth.
400   if (!inputETy.isInteger(dataTypeBitWidth)) {
401     if (((maxFpType != minFpType) ||
402          (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
403                                        inputETy.getIntOrFloatBitWidth())))
404       return emitOpError("min/max attributes types are incompatible with "
405                          "input/output element types.");
406   }
407 
408   return success();
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // TOSA Operator Quantization Builders.
413 //===----------------------------------------------------------------------===//
414 
415 /// This builder is called on all convolution operators except TransposeConv,
416 /// which has specialized output shape semantics. The builder also defines the
417 /// bitwidth of the output given the bitwidth of the input and weight content.
418 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
419                                      Type outputType, Value input, Value weight,
420                                      Value bias, DenseI64ArrayAttr pad,
421                                      DenseI64ArrayAttr stride,
422                                      DenseI64ArrayAttr dilation,
423                                      TypeAttr accType) {
424 
425   result.addOperands({input, weight, bias});
426   result.addAttribute("pad", pad);
427   result.addAttribute("stride", stride);
428   result.addAttribute("dilation", dilation);
429   result.addAttribute("acc_type", accType);
430 
431   auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
432   if (quantAttr) {
433     result.addAttribute("quantization_info", quantAttr);
434     result.addTypes(
435         buildConvOpResultTypeInfo(builder, outputType, input, weight));
436   } else {
437     result.addTypes(outputType);
438   }
439 }
440 
441 /// Handles tosa.transpose_conv2d which has outpad and output shape
442 /// attributes.
443 static void buildTransConvOpWithQuantInfo(
444     OpBuilder &builder, OperationState &result, Type outputType, Value input,
445     Value weight, Value bias, DenseI64ArrayAttr outpad,
446     DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
447   result.addOperands({input, weight, bias});
448   result.addAttribute("out_pad", outpad);
449   result.addAttribute("stride", stride);
450   result.addAttribute("out_shape", outputShape);
451   result.addAttribute("acc_type", accType);
452   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
453 
454   if (quantAttr) {
455     result.addAttribute("quantization_info", quantAttr);
456     result.addTypes(
457         buildConvOpResultTypeInfo(builder, outputType, input, weight));
458   } else {
459     result.addTypes(outputType);
460   }
461 }
462 
463 /// The tosa.fully_connected op has its own builder as it does not have
464 /// strides/dilation/padding.
465 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
466                                    Type outputType, Value input, Value weight,
467                                    Value bias) {
468 
469   result.addOperands({input, weight, bias});
470   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
471   if (quantAttr) {
472     result.addAttribute("quantization_info", quantAttr);
473     result.addTypes(
474         buildConvOpResultTypeInfo(builder, outputType, input, weight));
475   } else {
476     result.addTypes(outputType);
477   }
478 }
479 
480 /// The tosa.matmul op is also intended to be generated when a fully_connected
481 /// op must be constructed with a weight that is not a constant. In this case,
482 /// the fully_connected op must be expressed using matmul.
483 /// TODO: Add link to the legalization document explaining this.
484 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
485                                        OperationState &result, Type outputType,
486                                        Value a, Value b) {
487   result.addOperands({a, b});
488   auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
489 
490   if (quantAttr) {
491     result.addAttribute("quantization_info", quantAttr);
492 
493     auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
494     assert(inputType && "Input must be a shaped tensor type!");
495 
496     auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
497         inputType.getElementType());
498     assert(inputQType && "Tensor must have quantized datatype!");
499 
500     unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
501 
502     auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
503     assert(outputShapedType && "Output must be a shaped type");
504 
505     IntegerType accElementType;
506     if (inputBits == 16)
507       accElementType = builder.getIntegerType(48);
508     else
509       accElementType = builder.getI32Type();
510     auto accType = outputShapedType.clone(accElementType);
511     result.addTypes(accType);
512   } else {
513     result.addTypes(outputType);
514   }
515 }
516 
517 /// Both the tosa.avg_pool2d and unary ops use the same
518 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
519 /// as it has additional parameters not part of the unary ops.
520 static void
521 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
522                               Type outputType, Value input,
523                               DenseArrayAttr kernel, DenseArrayAttr stride,
524                               DenseArrayAttr pad, TypeAttr accType) {
525   result.addOperands(input);
526   result.addAttribute("kernel", kernel);
527   result.addAttribute("stride", stride);
528   result.addAttribute("pad", pad);
529   result.addAttribute("acc_type", accType);
530   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
531   if (quantAttr)
532     result.addAttribute("quantization_info", quantAttr);
533   result.types.push_back(outputType);
534 }
535 
536 /// This builder is called on single-parameter unary operators that have a
537 /// scale relationship between their input and output, expressed by the
538 /// UnaryOpQuantizationAttr.
539 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
540                                       OperationState &result, Type outputType,
541                                       Value input) {
542   result.addOperands(input);
543   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
544   if (quantAttr)
545     result.addAttribute("quantization_info", quantAttr);
546   result.types.push_back(outputType);
547 }
548 
549 /// This builder is called on the TOSA pad operator that needs to create its
550 /// own OptionalAttr quantization_attr parameter to scale the padding values
551 /// correctly. A missing pad_const is interpreted as zero padding.
552 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
553                                     Type outputType, Value input,
554                                     Value paddings) {
555   result.addOperands({input, paddings});
556   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
557   if (quantAttr)
558     result.addAttribute("quantization_info", quantAttr);
559   result.types.push_back(outputType);
560 }
561 
562 /// This builder is called on the TOSA pad operator when an explicit pad_const
563 /// value is passed in. It also optionally constructs quantization_attr.
564 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
565                                                  OperationState &result,
566                                                  Type outputType, Value input,
567                                                  Value paddings,
568                                                  Value padConst) {
569   result.addOperands({input, paddings, padConst});
570   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
571   if (quantAttr)
572     result.addAttribute("quantization_info", quantAttr);
573   result.types.push_back(outputType);
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // TOSA Operator Return Type Inference.
578 //===----------------------------------------------------------------------===//
579 
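/// Resolves the broadcast shape of all ranked operands, e.g. shapes [2, 1]
/// and [3] resolve to [2, 3]. Fails if any operand is unranked or the shapes
/// are incompatible.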
580 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
581                                            SmallVector<int64_t> &outShape) {
582   int64_t outRank = 0;
583   for (int i = 0, e = operands.size(); i != e; ++i) {
584     auto shape = operands.getShape(i);
585     if (!shape.hasRank()) {
586       // TODO(jennik): Update function to have better case handling for
587       // invalid operands and for ranked tensors.
588       return failure();
589     }
590     outRank = std::max<int64_t>(outRank, shape.getRank());
591   }
592 
593   outShape.resize(outRank, 1);
594 
595   for (int i = 0, e = operands.size(); i != e; ++i) {
596     auto shape = operands.getShape(i);
597     auto rankDiff = outShape.size() - shape.getRank();
598 
599     for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
600       auto dim1 = outShape[i + rankDiff];
601       auto dim2 = shape.getDimSize(i);
602       auto resolvedDim = dim1;
603 
604       if (dim1 == 1) {
605         resolvedDim = dim2;
606       } else if (dim2 == 1) {
607         resolvedDim = dim1;
608       } else if (dim1 != dim2) {
609         return failure();
610       }
611       outShape[i + rankDiff] = resolvedDim;
612     }
613   }
614 
615   return success();
616 }
617 
618 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
619     MLIRContext *context, ::std::optional<Location> location,
620     ArgMaxOp::Adaptor adaptor,
621     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
622   ShapeAdaptor inputShape(adaptor.getInput().getType());
623   IntegerAttr axis = adaptor.getProperties().axis;
624   int32_t axisVal = axis.getValue().getSExtValue();
625 
626   if (!inputShape.hasRank()) {
627     inferredReturnShapes.push_back(ShapedTypeComponents());
628     return success();
629   }
630 
631   SmallVector<int64_t> outShape;
632   outShape.reserve(inputShape.getRank() - 1);
633   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
634     if (i == axisVal)
635       continue;
636     outShape.push_back(inputShape.getDimSize(i));
637   }
638 
639   inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
640   return success();
641 }
642 
643 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
644     MLIRContext *context, ::std::optional<Location> location,
645     RFFT2dOp::Adaptor adaptor,
646     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
647   ShapeAdaptor inputShape(adaptor.getInput().getType());
648 
649   if (!inputShape.hasRank())
650     return failure();
651 
652   llvm::SmallVector<int64_t> outputShape;
653   outputShape.resize(3, ShapedType::kDynamic);
654   outputShape[0] = inputShape.getDimSize(0);
655   outputShape[1] = inputShape.getDimSize(1);
656   int64_t inWidth = inputShape.getDimSize(2);
657 
658   // Note that we can support this calculation symbolically
659   // in the future, e.g. [x, y, z] -> [x, y, z / 2 + 1]
660   if (inWidth != ShapedType::kDynamic)
661     outputShape[2] = inWidth / 2 + 1;
662 
663   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
664   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
665 
666   return success();
667 }
668 
669 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
670     MLIRContext *context, ::std::optional<Location> location,
671     FFT2dOp::Adaptor adaptor,
672     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
673   inferredReturnShapes.push_back(
674       ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
675   inferredReturnShapes.push_back(
676       ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
677   return success();
678 }
679 
680 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
681     MLIRContext *context, ::std::optional<Location> location,
682     ConcatOp::Adaptor adaptor,
683     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
684   // Infer all dimension sizes by reducing based on inputs.
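  // For example, concatenating shapes [2, 3] and [2, ?] along axis 1 yields
  // [2, ?], while [2, 3] and [2, 5] yield [2, 8].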
685   const Properties &prop = adaptor.getProperties();
686   int32_t axis = prop.axis.getValue().getSExtValue();
687   llvm::SmallVector<int64_t> outputShape;
688   bool hasRankedInput = false;
689   for (auto operand : adaptor.getOperands()) {
690     ShapeAdaptor operandShape(operand.getType());
691     if (!operandShape.hasRank())
692       continue;
693 
694     // Copy the Operand's rank.
695     if (!hasRankedInput)
696       outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
697 
698     // Copy static sizes for all non-axis dims, checking they agree across inputs.
699     for (int i = 0, s = operandShape.getRank(); i < s; i++) {
700       if (i == axis || operandShape.isDynamicDim(i))
701         continue;
702       if (outputShape[i] == ShapedType::kDynamic)
703         outputShape[i] = operandShape.getDimSize(i);
704       if (outputShape[i] != operandShape.getDimSize(i))
705         return emitOptionalError(location,
706                                  "Cannot concat tensors with different sizes"
707                                  " on the non-axis dimension ",
708                                  i);
709     }
710 
711     hasRankedInput = true;
712   }
713   Type inputType =
714       llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
715   if (!hasRankedInput) {
716     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
717     return success();
718   }
719 
720   // Determine the dimension size along the concatenation axis.
721   int64_t concatDimSize = 0;
722   for (auto operand : adaptor.getOperands()) {
723     ShapeAdaptor operandShape(operand.getType());
724 
725     // We need to know the length of the concatenation axis of all inputs to
726     // determine the dimension size of the output shape.
727     if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
728       concatDimSize = ShapedType::kDynamic;
729       break;
730     }
731 
732     concatDimSize += operandShape.getDimSize(axis);
733   }
734 
735   outputShape[axis] = concatDimSize;
736 
737   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
738   return success();
739 }
740 
741 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
742     MLIRContext *context, ::std::optional<Location> location,
743     ValueShapeRange operands, DictionaryAttr attributes,
744     OpaqueProperties properties, RegionRange regions,
745     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
746   auto elementType = IntegerType::get(context, /*width=*/1);
747 
748   llvm::SmallVector<int64_t> outShape;
749   if (resolveBroadcastShape(operands, outShape).failed()) {
750     inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
751     return success();
752   }
753 
754   inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
755   return success();
756 }
757 
758 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
759   if (l.size() != r.size() || l.size() != 1)
760     return false;
761   return succeeded(verifyCompatibleShape(l[0], r[0]));
762 }
763 
764 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
765     MLIRContext *context, ::std::optional<Location> location,
766     FullyConnectedOp::Adaptor adaptor,
767     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
768   ShapeAdaptor inputShape(adaptor.getInput().getType());
769   ShapeAdaptor weightShape(adaptor.getWeight().getType());
770   ShapeAdaptor biasShape(adaptor.getBias().getType());
771 
772   // All shapes are dynamic.
773   SmallVector<int64_t> outShape;
774   outShape.resize(2, ShapedType::kDynamic);
775 
776   if (inputShape.hasRank()) {
777     outShape[0] = inputShape.getDimSize(0);
778   }
779 
780   if (weightShape.hasRank()) {
781     outShape[1] = weightShape.getDimSize(0);
782   }
783 
784   if (biasShape.hasRank()) {
785     outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
786                                                       : outShape[1];
787   }
788 
789   inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
790   return success();
791 }
792 
793 LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
794 
795 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
796     MLIRContext *context, ::std::optional<Location> location,
797     MatMulOp::Adaptor adaptor,
798     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
799   ShapeAdaptor lhsShape(adaptor.getA().getType());
800   ShapeAdaptor rhsShape(adaptor.getB().getType());
801 
802   // All shapes are dynamic.
803   SmallVector<int64_t> outShape;
804   outShape.resize(3, ShapedType::kDynamic);
805 
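  // For example, A of shape [1, 3, ?] and B of shape [1, ?, 5] infer an
  // output shape of [1, 3, 5].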
806   if (lhsShape.hasRank()) {
807     outShape[0] = lhsShape.getDimSize(0);
808     outShape[1] = lhsShape.getDimSize(1);
809   }
810 
811   if (rhsShape.hasRank()) {
812     outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
813                                                       : outShape[0];
814     outShape[2] = rhsShape.getDimSize(2);
815   }
816 
817   inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
818   return success();
819 }
820 
821 LogicalResult tosa::PadOp::inferReturnTypeComponents(
822     MLIRContext *context, ::std::optional<Location> location,
823     PadOp::Adaptor adaptor,
824     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
825   ShapeAdaptor inputShape(adaptor.getInput1().getType());
826   auto paddingRank =
827       cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
828   SmallVector<int64_t> outputShape;
829 
830   // If the input rank is unknown, we can infer the output rank using the
831   // padding shape's rank divided by 2.
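  // For example, a rank-8 padding shape implies a rank-4 output tensor.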
832   if (!inputShape.hasRank()) {
833     outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
834     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
835     return success();
836   }
837 
838   SmallVector<int64_t> paddingValues;
839   // If the paddings value is not a constant, all dimensions must be dynamic.
840   if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
841                                 paddingValues)) {
842     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
843     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
844     return success();
845   }
846 
847   outputShape.reserve(inputShape.getRank());
848   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
849     if (inputShape.isDynamicDim(i)) {
850       outputShape.push_back(ShapedType::kDynamic);
851       continue;
852     }
853     auto padFront = paddingValues[i * 2];
854     auto padBack = paddingValues[i * 2 + 1];
855     if (padFront < 0 || padBack < 0) {
856       // if either padding for dim i is -1, output dim is unknown
857       outputShape.push_back(ShapedType::kDynamic);
858       continue;
859     }
860 
861     outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
862   }
863 
864   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
865   return success();
866 }
867 
868 LogicalResult tosa::PadOp::verify() {
869   RankedTensorType inputType = getInput1().getType();
870   RankedTensorType outputType = getOutput().getType();
871   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
872 
873   if (inputType.getRank() != outputType.getRank())
874     return emitOpError() << "expect same input and output tensor rank.";
875 
876   if (paddingRank != inputType.getRank() * 2)
877     return emitOpError() << "expected padding tensor dim 0 to have size "
878                          << inputType.getRank() * 2
879                          << " (2*rank(shape1)) but got size " << paddingRank;
880 
881   return success();
882 }
883 
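/// Converts a TOSA-style shape, where -1 marks an unknown dimension, to the
/// MLIR convention using ShapedType::kDynamic.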
884 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
885   return to_vector(llvm::map_range(shape, [](int64_t dim) {
886     return dim == -1 ? ShapedType::kDynamic : dim;
887   }));
888 }
889 
890 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
891     MLIRContext *context, ::std::optional<Location> location,
892     SliceOp::Adaptor adaptor,
893     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
894 
895   Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
896   SmallVector<int64_t> start;
897   SmallVector<int64_t> size;
898 
899   if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
900       !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
901     auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
902     SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
903     inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
904     return success();
905   }
906 
907   // if size[i] is -1, all remaining elements in dimension i are included
908   // in the slice, similar to TF.
909   ShapeAdaptor inputShape(adaptor.getInput1().getType());
910   // initialize outputShape to all unknown
911   SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
912   if (inputShape.hasRank()) {
913     for (size_t i = 0; i < size.size(); i++) {
914       if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
915           (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
916            start[i] < inputShape.getDimSize(i))) {
917         // size[i] is not 0 and not < -1, and start[i] is in valid range
918         if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
919           // input shape has unknown dim[i] - only valid if size[i] > 0
920           if (size[i] > 0) {
921             outputShape[i] = size[i];
922           }
923         } else {
924           // input shape has known dim[i]
925           if (size[i] == -1) {
926             outputShape[i] = inputShape.getDimSize(i) - start[i];
927           } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
928             // start[i] + size[i] is within bound of input shape's dim[i]
929             outputShape[i] = size[i];
930           }
931         }
932       }
933     }
934   } else {
935     outputShape = convertToMlirShape(size);
936   }
937   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
938   return success();
939 }
940 
941 LogicalResult tosa::SliceOp::verify() {
942   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
943   if (!inputType)
944     return success();
945 
946   auto startShapeRank =
947       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
948   if (inputType.getRank() != startShapeRank)
949     return emitOpError(
950         "length of start attribute is not equal rank of input shape");
951 
952   auto sizeShapeRank =
953       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
954   if (inputType.getRank() != sizeShapeRank)
955     return emitOpError(
956         "length of size attribute is not equal rank of input shape");
957 
958   return success();
959 }
960 
961 LogicalResult tosa::MulOp::verify() {
962   auto resElemType = getElementTypeOrSelf(getOutput());
963 
964   // Verify that the element types of the operands and the result match the
965   // TOSA specification.
966   if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
967     IntegerType lhsIntType =
968         cast<IntegerType>(getElementTypeOrSelf(getInput1()));
969     IntegerType rhsIntType =
970         cast<IntegerType>(getElementTypeOrSelf(getInput2()));
971     if (lhsIntType != rhsIntType)
972       return emitOpError("requires the same element type for all operands");
973 
974     // Though the spec requires the element type of result to be i32, a more
975     // relaxed rule is applied at the dialect level to ease interoperation
976     // with other dialects.
977     if (lhsIntType.getWidth() > resIntType.getWidth())
978       return emitOpError("invalid data type size for operands or result");
979 
980   } else {
981     // For other supported types, the spec requires the same element type for
982     // all operands (excluding the `shift` operand) and results.
983     for (int i = 0; i < 2; ++i) {
984       if (getElementTypeOrSelf(getOperand(i)) != resElemType)
985         return emitOpError(
986             "requires the same element type for all operands and results");
987     }
988   }
989 
990   // Verify that the op has the same rank for all main operands and result
991   // types, if known (this excludes extra operands such as mul's shift, which
992   // is the only difference from the built-in `SameOperandsAndResultRank` trait).
993 
994   // Delegate function that returns true if type is a shaped type with known
995   // rank.
996   auto hasRank = [](const Type type) {
997     if (auto shapedType = dyn_cast<ShapedType>(type))
998       return shapedType.hasRank();
999 
1000     return false;
1001   };
1002 
1003   auto rankedOperandTypes =
1004       llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
1005 
1006   auto rankedResultTypes =
1007       llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1008 
1009   // If all operands and results are unranked, then no further verification.
1010   if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1011     return success();
1012 
1013   // Delegate function that returns the rank of a shaped type with known rank.
1014   auto getRank = [](const Type type) {
1015     return cast<ShapedType>(type).getRank();
1016   };
1017 
1018   auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1019                                           : getRank(*rankedResultTypes.begin());
1020 
1021   for (size_t i = 0; i < 2; ++i) {
1022     if (rank != getRank(rankedOperandTypes[i])) {
1023       return emitOpError("operands don't have matching ranks");
1024     }
1025   }
1026 
1027   for (const auto type : rankedResultTypes) {
1028     if (rank != getRank(type)) {
1029       return emitOpError("result type has different rank than operands");
1030     }
1031   }
1032 
1033   return success();
1034 }
1035 
1036 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1037     MLIRContext *context, ::std::optional<Location> location,
1038     TableOp::Adaptor adaptor,
1039     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1040   ShapeAdaptor inputShape(adaptor.getInput1().getType());
1041 
1042   if (!inputShape.hasRank()) {
1043     inferredReturnShapes.push_back(ShapedTypeComponents());
1044     return success();
1045   }
1046 
1047   inferredReturnShapes.resize(1);
1048   inputShape.getDims(inferredReturnShapes[0]);
1049   return success();
1050 }
1051 
1052 LogicalResult tosa::TableOp::verify() {
1053   TensorType inputType = getInput1().getType();
1054   TensorType outputType = getOutput().getType();
1055 
1056   if (inputType.hasRank() && outputType.hasRank() &&
1057       inputType.getRank() != outputType.getRank())
1058     return emitOpError()
1059            << "expected input tensor rank to equal result tensor rank";
1060 
1061   auto inputDims = inputType.getShape();
1062   auto outputDims = outputType.getShape();
1063   for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1064     int64_t dim = it.index();
1065     auto [inputDim, outputDim] = it.value();
1066     if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1067       return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1068                            << " doesn't match dim(input, " << dim
1069                            << ") = " << inputDim;
1070     }
1071   }
1072   return success();
1073 }
1074 
1075 LogicalResult
1076 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1077   // Multiples must be constants.
1078   DenseIntElementsAttr multiplesAttr;
1079   if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1080     return failure();
1081   multiples = llvm::to_vector(
1082       llvm::map_range(multiplesAttr.getValues<APInt>(),
1083                       [](const APInt &val) { return val.getSExtValue(); }));
1084   return success();
1085 }
1086 
1087 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1088     MLIRContext *context, ::std::optional<Location> location,
1089     TileOp::Adaptor adaptor,
1090     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1091   DenseIntElementsAttr multiplesAttr;
1092   if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
1093     return failure();
1094 
1095   SmallVector<int64_t> multiples = llvm::to_vector(
1096       llvm::map_range(multiplesAttr.getValues<APInt>(),
1097                       [](const APInt &val) { return val.getSExtValue(); }));
1098 
1099   ShapeAdaptor inputShape(adaptor.getInput1().getType());
1100   SmallVector<int64_t> outputShape;
1101   if (!inputShape.hasRank()) {
1102     outputShape.resize(multiples.size(), ShapedType::kDynamic);
1103     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1104     return success();
1105   } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
1106     return failure();
1107 
1108   // Any non-dynamic dimension can be multiplied to a known size.
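  // For example, an input of shape [2, ?] with multiples [3, 2] yields an
  // output of shape [6, ?].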
1109   outputShape.reserve(multiples.size());
1110   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1111     int64_t dim = inputShape.getDimSize(i);
1112     if (dim != ShapedType::kDynamic)
1113       dim *= multiples[i];
1114     outputShape.push_back(dim);
1115   }
1116 
1117   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1118   return success();
1119 }
1120 
1121 LogicalResult tosa::TileOp::verify() {
1122   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
1123   ShapedType outputType = llvm::cast<ShapedType>(getType());
1124 
1125   shapeType multiplesType =
1126       llvm::cast<tosa::shapeType>(getMultiples().getType());
1127 
1128   auto multiplesRank = multiplesType.getRank();
1129 
1130   if (inputType.hasRank()) {
1131     if (inputType.getRank() != multiplesRank)
1132       return emitOpError("expect 'multiples' to have rank ")
1133              << inputType.getRank() << " but got " << multiplesRank << ".";
1134     if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
1135       return emitOpError("expect same input and output tensor rank.");
1136   } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
1137     return emitOpError("expect 'multiples' array to have length ")
1138            << outputType.getRank() << " but got " << multiplesRank << ".";
1139 
1140   SmallVector<int64_t> multiples;
1141   if (getConstantMultiples(multiples).succeeded() &&
1142       llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
1143     return emitOpError(
1144         "expect element of 'multiples' to be positive integer or -1.");
1145 
1146   return success();
1147 }
1148 
1149 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1150   if (l.size() != r.size() || l.size() != 1)
1151     return false;
1152   return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
1153 }
1154 
1155 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1156     MLIRContext *context, ::std::optional<Location> location,
1157     ReshapeOp::Adaptor adaptor,
1158     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1159   ShapeAdaptor inputShape(adaptor.getInput1().getType());
1160   Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1161   llvm::SmallVector<int64_t> newShapeValue =
1162       convertToMlirShape(adaptor.getNewShape());
1163 
1164   // We cannot infer from the total number of elements so we must take the
1165   // shape attribute as exact.
1166   if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
1167     inferredReturnShapes.push_back(
1168         ShapedTypeComponents(newShapeValue, inputType));
1169     return success();
1170   }
1171 
1172   // Determine the number of elements covered by the slice of all static
1173   // dimensions. This allows us to infer the length of the remaining dynamic
1174   // dimension.
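  // For example, reshaping a static 2x6 input with new_shape [3, -1] infers
  // the dynamic dimension as 12 / 3 = 4.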
1175   int64_t numElements = inputShape.getNumElements();
1176   int64_t staticMul = 1;
1177   for (auto val : newShapeValue) {
1178     if (!ShapedType::isDynamic(val)) {
1179       staticMul *= val;
1180     }
1181   }
1182 
1183   // Determine the length of the dynamic dimension.
1184   for (auto &val : newShapeValue) {
1185     if (ShapedType::isDynamic(val))
1186       val = numElements / staticMul;
1187   }
1188 
1189   inferredReturnShapes.push_back(
1190       ShapedTypeComponents(newShapeValue, inputType));
1191   return success();
1192 }
1193 
1194 llvm::LogicalResult tosa::ReshapeOp::verify() {
1195   TensorType inputType = getInput1().getType();
1196   RankedTensorType outputType = getType();
1197 
1198   if ((int64_t)getNewShape().size() != outputType.getRank())
1199     return emitOpError() << "new shape does not match result rank";
1200 
1201   for (auto [newShapeDim, outputShapeDim] :
1202        zip(getNewShape(), outputType.getShape())) {
1203     if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
1204         newShapeDim != outputShapeDim)
1205       return emitOpError() << "new shape is inconsistent with result shape";
1206 
1207     if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1208       return emitOpError() << "new shape has invalid tensor dimension size "
1209                            << newShapeDim;
1210   }
1211 
1212   if (inputType.hasStaticShape()) {
1213     int64_t inputElementsNum = inputType.getNumElements();
1214     if (outputType.hasStaticShape()) {
1215       int64_t outputElementsNum = outputType.getNumElements();
1216       if (inputElementsNum != outputElementsNum) {
1217         return emitOpError() << "cannot reshape " << inputElementsNum
1218                              << " elements into " << outputElementsNum;
1219       }
1220     }
1221 
1222     int64_t newShapeElementsNum = std::accumulate(
1223         getNewShape().begin(), getNewShape().end(), 1LL,
1224         [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1225     bool isStaticNewShape =
1226         llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
1227     if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1228         (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1229       return emitOpError() << "cannot reshape " << inputElementsNum
1230                            << " elements into " << newShapeElementsNum;
1231     }
1232   }
1233 
1234   int missingDims = llvm::count(getNewShape(), -1);
1235   if (missingDims > 1)
1236     return emitOpError() << "expected at most one target dimension to be -1";
1237 
1238   return mlir::success();
1239 }
1240 
1241 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
1242   // Perms must be constants.
1243   DenseIntElementsAttr permsAttr;
1244   if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
1245     return failure();
1246 
1247   perms.clear();
1248   for (auto v : permsAttr.getValues<APInt>())
1249     perms.push_back(v.getSExtValue());
1250 
1251   return success();
1252 }
1253 
1254 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1255     MLIRContext *context, ::std::optional<Location> location,
1256     TransposeOp::Adaptor adaptor,
1257     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1258   ShapeAdaptor inputShape(adaptor.getInput1().getType());
1259   ShapeAdaptor permsShape(adaptor.getPerms().getType());
1260 
1261   // We cannot infer anything from a rank-0 "permutation" tensor.
1262   if (permsShape.hasRank() && permsShape.getRank() == 0)
1263     return failure();
1264 
1265   // If the input rank or permutation length is unknown, the output rank is
1266   // unknown.
1267   if (!inputShape.hasRank() || !permsShape.hasRank() ||
1268       permsShape.isDynamicDim(0)) {
1269     inferredReturnShapes.push_back(ShapedTypeComponents());
1270     return success();
1271   }
1272 
1273   // This would imply the number of permutations does not match the rank of
1274   // the input which is illegal.
1275   if (permsShape.getDimSize(0) != inputShape.getRank()) {
1276     return failure();
1277   }
1278 
1279   SmallVector<int64_t> outputShape;
1280   // Rank-0 means no permutations matter.
1281   if (inputShape.getRank() == 0) {
1282     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1283     return success();
1284   }
1285 
1286   // Check whether the input dimensions are all the same.
1287   bool allTheSame = true;
1288   for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1289     if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1290       allTheSame = false;
1291       break;
1292     }
1293   }
1294 
1295   // If all of the input dimensions are the same we don't care about the
1296   // permutation.
1297   if (allTheSame) {
1298     outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1299     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1300     return success();
1301   }
1302 
1303   outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1304   // If the permutations are a constant, we can directly determine the output
1305   // shape.
1306   DenseIntElementsAttr attr;
1307   if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1308       attr.getType().getRank() == 1) {
1309     ShapeAdaptor permShape = attr;
1310     // Constant permutation must be the same length as the input rank.
1311     if (inputShape.getRank() != permShape.getRank())
1312       return emitOptionalError(location,
1313                                "constant permutation must be the same length"
1314                                " as the input rank");
1315 
1316     // Constant permutation values must be within the input rank.
1317     for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1318       if (inputShape.getRank() <= permShape.getDimSize(i))
1319         return failure();
1320     }
1321 
1322     outputShape.reserve(inputShape.getRank());
1323     for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1324       outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1325     }
1326   }
1327 
1328   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1329   return success();
1330 }
1331 
1332 LogicalResult tosa::TransposeOp::verify() {
1333   TensorType inputType = getInput1().getType();
1334   TensorType permType = getPerms().getType();
1335   TensorType outputType = getOutput().getType();
1336 
1337   if (permType.hasRank() && permType.getRank() != 1)
1338     return emitOpError()
1339            << "expected permutation tensor to be rank 1 but got rank "
1340            << permType.getRank();
1341   if (inputType.hasRank() && permType.hasRank())
1342     if (!permType.isDynamicDim(0) &&
1343         permType.getDimSize(0) != inputType.getRank())
1344       return emitOpError() << "expected permutation tensor dim 0 to have size "
1345                            << inputType.getRank()
1346                            << " (input rank) but got size "
1347                            << permType.getDimSize(0);
1348   if (inputType.hasRank() && outputType.hasRank() &&
1349       inputType.getRank() != outputType.getRank())
1350     return emitOpError()
1351            << "expected input tensor rank to equal result tensor rank";
1352   if (outputType.hasRank() && permType.hasRank())
1353     if (!permType.isDynamicDim(0) &&
1354         permType.getDimSize(0) != outputType.getRank())
1355       return emitOpError() << "expected permutation tensor dim 0 to have size "
1356                            << outputType.getRank()
1357                            << " (output rank) but got size "
1358                            << permType.getDimSize(0);
1359 
1360   SmallVector<int32_t> constantPerms;
1361   if (succeeded(getConstantPerms(constantPerms))) {
1362     // Assert that the permutation tensor has a rank, which means that the
1363     // rank has been verified above.
1364     assert(permType.hasRank() &&
1365            "Unexpectedly found permutation tensor without rank");
1366     if (!llvm::all_of(constantPerms,
1367                       [&constantPerms](int32_t s) {
1368                         return s >= 0 &&
1369                                static_cast<size_t>(s) < constantPerms.size();
1370                       }) ||
1371         !isPermutationVector(llvm::to_vector(llvm::map_range(
1372             constantPerms, [](int32_t v) -> int64_t { return v; }))))
1373       return emitOpError() << "expected valid permutation tensor";
1374 
1375     // Verify that the types of the input and output tensors are properly
1376     // permuted.
1377     if (inputType.hasRank() && outputType.hasRank()) {
1378       assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1379              inputType.getRank() == outputType.getRank());
1380 
1381       for (auto i = 0; i < outputType.getRank(); i++) {
1382         if (inputType.isDynamicDim(constantPerms[i]) ||
1383             outputType.isDynamicDim(i))
1384           continue;
1385 
1386         if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
1387           return emitOpError()
1388                  << "expected output tensor dim " << i << " to match "
1389                  << "input dim " << constantPerms[i] << " with value of "
1390                  << inputType.getDimSize(constantPerms[i]);
1391       }
1392     }
1393   }
1394   return success();
1395 }
1396 
1397 LogicalResult TransposeOp::reifyResultShapes(
1398     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1399 
1400   SmallVector<int32_t> transposePerms;
1401   if (getConstantPerms(transposePerms).failed())
1402     return failure();
1403 
1404   Value input = getInput1();
1405   auto inputType = cast<TensorType>(input.getType());
1406 
1407   SmallVector<OpFoldResult> returnedDims(inputType.getRank());
1408   for (auto dim : transposePerms) {
1409     int32_t dimInInput = transposePerms[dim];
1410     if (inputType.isDynamicDim(dimInInput))
1411       returnedDims[dim] =
1412           builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
1413               .getResult();
1414     else
1415       returnedDims[dim] =
1416           builder.getIndexAttr(inputType.getDimSize(dimInInput));
1417   }
1418 
1419   reifiedReturnShapes.emplace_back(std::move(returnedDims));
1420   return success();
1421 }
1422 
1423 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1424     MLIRContext *context, ::std::optional<Location> location,
1425     GatherOp::Adaptor adaptor,
1426     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1427   llvm::SmallVector<int64_t> outputShape;
1428   outputShape.resize(3, ShapedType::kDynamic);
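  // The output shape is [N, W, C]: values supplies N and C, indices supplies
  // N and W. For example, values of shape [2, 10, 4] and indices of shape
  // [2, 3] infer an output shape of [2, 3, 4].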
1429 
1430   ShapeAdaptor valuesShape(adaptor.getValues().getType());
1431   if (valuesShape.hasRank()) {
1432     outputShape[0] = valuesShape.getDimSize(0);
1433     outputShape[2] = valuesShape.getDimSize(2);
1434   }
1435 
1436   ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1437   if (indicesShape.hasRank()) {
1438     if (outputShape[0] == ShapedType::kDynamic)
1439       outputShape[0] = indicesShape.getDimSize(0);
1440     if (outputShape[1] == ShapedType::kDynamic)
1441       outputShape[1] = indicesShape.getDimSize(1);
1442   }
1443 
1444   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1445   return success();
1446 }
1447 
1448 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1449     MLIRContext *context, ::std::optional<Location> location,
1450     ResizeOp::Adaptor adaptor,
1451     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1452   llvm::SmallVector<int64_t, 4> outputShape;
1453   outputShape.resize(4, ShapedType::kDynamic);
1454 
1455   ShapeAdaptor inputShape(adaptor.getInput().getType());
1456   if (!inputShape.hasRank())
1457     return failure();
1458 
1459   outputShape[0] = inputShape.getDimSize(0);
1460   outputShape[3] = inputShape.getDimSize(3);
1461   int64_t inputHeight = inputShape.getDimSize(1);
1462   int64_t inputWidth = inputShape.getDimSize(2);
1463 
1464   if ((inputHeight == ShapedType::kDynamic) ||
1465       (inputWidth == ShapedType::kDynamic))
1466     return failure();
1467 
1468   llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1469   llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1470   llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1471 
1472   // Compute the output shape based on attributes: scale, offset, and border.
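  // Scale is laid out as [scale_y_n, scale_y_d, scale_x_n, scale_x_d]. For
  // example, an input height of 8 with scale = [2, 1, 2, 1], zero offset and
  // zero border gives ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15 output rows.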
1473   outputShape[1] =
1474       (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1475        scaleInt[1]) +
1476       1;
1477 
1478   outputShape[2] =
1479       (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1480        scaleInt[3]) +
1481       1;
1482 
1483   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1484   return success();
1485 }
1486 
1487 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1488     MLIRContext *context, ::std::optional<Location> location,
1489     ScatterOp::Adaptor adaptor,
1490     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1491   llvm::SmallVector<int64_t> outputShape;
1492   outputShape.resize(3, ShapedType::kDynamic);
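  // The output shape matches values_in ([N, K, C]); indices and input fill in
  // N and C when values_in leaves them dynamic. For example, values_in of
  // shape [?, 10, 4] and input of shape [2, 3, 4] infer an output of
  // [2, 10, 4].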
1493 
1494   ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
1495   if (valuesInShape.hasRank()) {
1496     outputShape[0] = valuesInShape.getDimSize(0);
1497     outputShape[1] = valuesInShape.getDimSize(1);
1498     outputShape[2] = valuesInShape.getDimSize(2);
1499   }
1500 
1501   ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1502   if (indicesShape.hasRank()) {
1503     if (outputShape[0] == ShapedType::kDynamic)
1504       outputShape[0] = indicesShape.getDimSize(0);
1505   }
1506 
1507   ShapeAdaptor inputShape(adaptor.getInput().getType());
1508   if (inputShape.hasRank()) {
1509     if (outputShape[0] == ShapedType::kDynamic)
1510       outputShape[0] = inputShape.getDimSize(0);
1511     if (outputShape[2] == ShapedType::kDynamic)
1512       outputShape[2] = inputShape.getDimSize(2);
1513   }
1514 
1515   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1516   return success();
1517 }
1518 
1519 static LogicalResult ReduceInferReturnTypes(
1520     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1521     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1522   int64_t axisVal = axis.getValue().getSExtValue();
1523   if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
1524     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1525     return success();
1526   }
1527 
1528   SmallVector<int64_t> outputShape;
1529   operandShape.getDims(outputShape);
1530   outputShape[axisVal] = 1;
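  // For example, reducing an input of shape [2, 3, 4] along axis 1 yields an
  // output of shape [2, 1, 4].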
1531   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1532   return success();
1533 }
1534 
1535 #define COMPATIBLE_RETURN_TYPES(OP)                                            \
1536   bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
1537     if (l.size() != r.size() || l.size() != 1)                                 \
1538       return false;                                                            \
1539     if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
1540       return false;                                                            \
1541     return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
1542   }
1543 
1544 #define REDUCE_SHAPE_INFER(OP)                                                 \
1545   LogicalResult OP::inferReturnTypeComponents(                                 \
1546       MLIRContext *context, ::std::optional<Location> location,                \
1547       OP::Adaptor adaptor,                                                     \
1548       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
1549     Type inputType =                                                           \
1550         llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1551     ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
1552     const Properties &prop = adaptor.getProperties();                          \
1553     return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
1554                                   inferredReturnShapes);                       \
1555   }                                                                            \
1556   COMPATIBLE_RETURN_TYPES(OP)
1557 
1558 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1559 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1560 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1561 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1562 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1563 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1564 #undef REDUCE_SHAPE_INFER
1565 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
1566 #undef COMPATIBLE_RETURN_TYPES
1567 
1568 template <typename T>
1569 static LogicalResult verifyReduceOp(T op) {
1570   // All TOSA reduce Ops have input, output and axis.
1571   TensorType inputType = op.getInput().getType();
1572   TensorType outputType = op.getOutput().getType();
1573   int32_t reduceAxis = op.getAxis();
1574 
1575   if (reduceAxis < 0) {
1576     op.emitOpError("reduce axis must not be negative");
1577     return failure();
1578   }
1579   if (inputType.hasRank()) {
1580     int64_t inputRank = inputType.getRank();
1581     // We allow for a special case where the input/output shape has rank 0 and
1582     // axis is also 0.
1583     if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1584       op.emitOpError("expect input tensor rank (")
1585           << inputRank << ") to be larger than reduce axis (" << reduceAxis
1586           << ")";
1587       return failure();
1588     }
1589   }
1590   if (outputType.hasRank()) {
1591     int64_t outputRank = outputType.getRank();
1592     if (inputType.hasRank() && outputRank != inputType.getRank()) {
1593       op.emitOpError(
1594           "expect output tensor rank to be equal to input tensor rank");
1595       return failure();
1596     }
1597     if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1598       op.emitOpError("expect output tensor rank (")
1599           << outputRank << ") to be larger than reduce axis (" << reduceAxis
1600           << ")";
1601       return failure();
1602     }
1603     // We can only verify the reduced dimension size to be 1 if this is not
1604     // the special case of output rank == 0.
1605     if (outputRank != 0) {
1606       auto outputShape = outputType.getShape();
1607       if (!outputType.isDynamicDim(reduceAxis) &&
1608           outputShape[reduceAxis] != 1) {
1609         op.emitOpError("expect reduced dimension size to be 1, got ")
1610             << outputShape[reduceAxis];
1611         return failure();
1612       }
1613     }
1614   }
1615   return success();
1616 }
1617 
1618 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1619 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1620 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1621 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1622 LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1623 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1624 
1625 static LogicalResult NAryInferReturnTypes(
1626     const ValueShapeRange &operands,
1627     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
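  // Elementwise TOSA ops broadcast their operands; for example, shapes
  // [2, 1, 3] and [1, 4, 3] resolve to an output shape of [2, 4, 3]. If the
  // shapes cannot be broadcast, the result is left unranked.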
1628   llvm::SmallVector<int64_t> outShape;
1629   if (resolveBroadcastShape(operands, outShape).failed()) {
1630     inferredReturnShapes.push_back(ShapedTypeComponents());
1631   } else {
1632     inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1633   }
1634   return success();
1635 }
1636 
1637 #define NARY_SHAPE_INFER(OP)                                                   \
1638   LogicalResult OP::inferReturnTypeComponents(                                 \
1639       MLIRContext *context, ::std::optional<Location> location,                \
1640       ValueShapeRange operands, DictionaryAttr attributes,                     \
1641       OpaqueProperties properties, RegionRange regions,                        \
1642       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
1643     return NAryInferReturnTypes(operands, inferredReturnShapes);               \
1644   }
1645 
1646 NARY_SHAPE_INFER(tosa::AbsOp)
1647 NARY_SHAPE_INFER(tosa::AddOp)
1648 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1649 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1650 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1651 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1652 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1653 NARY_SHAPE_INFER(tosa::CastOp)
1654 NARY_SHAPE_INFER(tosa::CeilOp)
1655 NARY_SHAPE_INFER(tosa::ClampOp)
1656 NARY_SHAPE_INFER(tosa::ClzOp)
1657 NARY_SHAPE_INFER(tosa::CosOp)
1658 NARY_SHAPE_INFER(tosa::ExpOp)
1659 NARY_SHAPE_INFER(tosa::FloorOp)
1660 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1661 NARY_SHAPE_INFER(tosa::GreaterOp)
1662 NARY_SHAPE_INFER(tosa::IdentityOp)
1663 NARY_SHAPE_INFER(tosa::IntDivOp)
1664 NARY_SHAPE_INFER(tosa::LogOp)
1665 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1666 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1667 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1668 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1669 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1670 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1671 NARY_SHAPE_INFER(tosa::MaximumOp)
1672 NARY_SHAPE_INFER(tosa::MinimumOp)
1673 NARY_SHAPE_INFER(tosa::MulOp)
1674 NARY_SHAPE_INFER(tosa::NegateOp)
1675 NARY_SHAPE_INFER(tosa::PowOp)
1676 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1677 NARY_SHAPE_INFER(tosa::RescaleOp)
1678 NARY_SHAPE_INFER(tosa::ReverseOp)
1679 NARY_SHAPE_INFER(tosa::RsqrtOp)
1680 NARY_SHAPE_INFER(tosa::SinOp)
1681 NARY_SHAPE_INFER(tosa::SelectOp)
1682 NARY_SHAPE_INFER(tosa::SubOp)
1683 NARY_SHAPE_INFER(tosa::TanhOp)
1684 NARY_SHAPE_INFER(tosa::ErfOp)
1685 NARY_SHAPE_INFER(tosa::SigmoidOp)
1686 #undef NARY_SHAPE_INFER
1687 
1688 static LogicalResult poolingInferReturnTypes(
1689     ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1690     ArrayRef<int64_t> pad,
1691     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1692   llvm::SmallVector<int64_t> outputShape;
1693   outputShape.resize(4, ShapedType::kDynamic);
1694 
1695   // If the input is unranked we only know the output rank (4), nothing more.
1696   if (!inputShape) {
1697     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1698     return success();
1699   }
1700 
1701   // Batch and number of channels are unchanged by a pooling layer.
1702   outputShape[0] = inputShape.getDimSize(0);
1703   outputShape[3] = inputShape.getDimSize(3);
1704 
1705   int64_t height = inputShape.getDimSize(1);
1706   int64_t width = inputShape.getDimSize(2);
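  // The spatial dims follow the usual pooling formula; for example, a height
  // of 32 with pad[0] = pad[1] = 1, kernel[0] = 3 and stride[0] = 2 gives
  // (32 + 1 + 1 - 3) / 2 + 1 = 16 output rows.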
1707 
1708   if (!ShapedType::isDynamic(height)) {
1709     int64_t padded = height + pad[0] + pad[1] - kernel[0];
1710     outputShape[1] = padded / stride[0] + 1;
1711   }
1712 
1713   if (!ShapedType::isDynamic(width)) {
1714     int64_t padded = width + pad[2] + pad[3] - kernel[1];
1715     outputShape[2] = padded / stride[1] + 1;
1716   }
1717 
1718   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1719   return success();
1720 }
1721 
1722 LogicalResult Conv2DOp::inferReturnTypeComponents(
1723     MLIRContext *context, ::std::optional<Location> location,
1724     Conv2DOp::Adaptor adaptor,
1725     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1726   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1727 
1728   int64_t inputWidth = ShapedType::kDynamic;
1729   int64_t inputHeight = ShapedType::kDynamic;
1730   int64_t weightWidth = ShapedType::kDynamic;
1731   int64_t weightHeight = ShapedType::kDynamic;
1732 
1733   // Input shape describes input width/height and batch.
1734 
1735   ShapeAdaptor inputShape(adaptor.getInput().getType());
1736   if (inputShape.hasRank()) {
1737     outputShape[0] = inputShape.getDimSize(0);
1738     inputHeight = inputShape.getDimSize(1);
1739     inputWidth = inputShape.getDimSize(2);
1740   }
1741 
1742   // Weight shape describes the filter height/width and the output channels.
1743   ShapeAdaptor weightShape(adaptor.getWeight().getType());
1744   if (weightShape.hasRank()) {
1745     outputShape[3] = weightShape.getDimSize(0);
1746     weightHeight = weightShape.getDimSize(1);
1747     weightWidth = weightShape.getDimSize(2);
1748   }
1749 
1750   // Bias shape can describe the output channels.
1751   ShapeAdaptor biasShape(adaptor.getBias().getType());
1752   if (biasShape.hasRank()) {
1753     outputShape[3] = ShapedType::isDynamic(outputShape[3])
1754                          ? biasShape.getDimSize(0)
1755                          : outputShape[3];
1756   }
1757 
1758   llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1759   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1760   llvm::ArrayRef<int64_t> padding = adaptor.getPad();
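  // Standard convolution output-size arithmetic. For example, an input height
  // of 16 with padding[0] = padding[1] = 1, a 3-high filter, dilation 1 and
  // stride 2 gives an unstrided result of 18 - 3 + 1 = 16 and an output
  // height of (16 - 1) / 2 + 1 = 8.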
1761 
1762   if (!ShapedType::isDynamic(inputHeight) &&
1763       !ShapedType::isDynamic(weightHeight)) {
1764     int64_t inputSize = inputHeight + padding[0] + padding[1];
1765     int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1766     int64_t unstridedResult = inputSize - filterSize + 1;
1767     outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1768   }
1769 
1770   if (!ShapedType::isDynamic(inputWidth) &&
1771       !ShapedType::isDynamic(weightWidth)) {
1772     int64_t inputSize = inputWidth + padding[2] + padding[3];
1773     int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1774     int64_t unstridedResult = inputSize - filterSize + 1;
1775     outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1776   }
1777 
1778   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1779   return success();
1780 }
1781 
1782 LogicalResult Conv2DOp::verify() {
1783   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1784     return failure();
1785   return success();
1786 }
1787 
1788 LogicalResult Conv3DOp::inferReturnTypeComponents(
1789     MLIRContext *context, ::std::optional<Location> location,
1790     Conv3DOp::Adaptor adaptor,
1791     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1792   llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
1793 
1794   int64_t inputWidth = ShapedType::kDynamic;
1795   int64_t inputHeight = ShapedType::kDynamic;
1796   int64_t inputDepth = ShapedType::kDynamic;
1797 
1798   int64_t weightWidth = ShapedType::kDynamic;
1799   int64_t weightHeight = ShapedType::kDynamic;
1800   int64_t weightDepth = ShapedType::kDynamic;
1801 
1802   // Input shape describes the batch and input depth/height/width.
1803   ShapeAdaptor inputShape(adaptor.getInput().getType());
1804   if (inputShape.hasRank()) {
1805     outputShape[0] = inputShape.getDimSize(0);
1806     inputDepth = inputShape.getDimSize(1);
1807     inputHeight = inputShape.getDimSize(2);
1808     inputWidth = inputShape.getDimSize(3);
1809   }
1810 
1811   // Weight shape describes the filter depth/height/width and output channels.
1812   ShapeAdaptor weightShape(adaptor.getWeight().getType());
1813   if (weightShape.hasRank()) {
1814     outputShape[4] = weightShape.getDimSize(0);
1815     weightDepth = weightShape.getDimSize(1);
1816     weightHeight = weightShape.getDimSize(2);
1817     weightWidth = weightShape.getDimSize(3);
1818   }
1819 
1820   // Bias shape can describe the output channels.
1821   ShapeAdaptor biasShape(adaptor.getBias().getType());
1822   if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1823     outputShape[4] = biasShape.getDimSize(0);
1824   }
1825 
1826   llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1827   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1828   llvm::ArrayRef<int64_t> pad = adaptor.getPad();
1829 
1830   if (!ShapedType::isDynamic(inputDepth) &&
1831       !ShapedType::isDynamic(weightDepth)) {
1832     int32_t inputSize = inputDepth + pad[0] + pad[1];
1833     int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
1834     int32_t unstridedResult = inputSize - filterSize + 1;
1835     outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1836   }
1837 
1838   if (!ShapedType::isDynamic(inputHeight) &&
1839       !ShapedType::isDynamic(weightHeight)) {
1840     int32_t inputSize = inputHeight + pad[2] + pad[3];
1841     int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
1842     int32_t unstridedResult = inputSize - filterSize + 1;
1843     outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1844   }
1845 
1846   if (!ShapedType::isDynamic(inputWidth) &&
1847       !ShapedType::isDynamic(weightWidth)) {
1848     int32_t inputSize = inputWidth + pad[4] + pad[5];
1849     int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
1850     int32_t unstridedResult = inputSize - filterSize + 1;
1851     outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1852   }
1853 
1854   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1855   return success();
1856 }
1857 
1858 LogicalResult Conv3DOp::verify() {
1859   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1860     return failure();
1861   return success();
1862 }
1863 
1864 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1865     MLIRContext *context, ::std::optional<Location> location,
1866     AvgPool2dOp::Adaptor adaptor,
1867     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1868   ShapeAdaptor inputShape(adaptor.getInput().getType());
1869   const Properties &prop = adaptor.getProperties();
1870   return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1871                                  inferredReturnShapes);
1872 }
1873 
1874 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
1875     MLIRContext *context, ::std::optional<Location> location,
1876     MaxPool2dOp::Adaptor adaptor,
1877     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1878   ShapeAdaptor inputShape(adaptor.getInput().getType());
1879   const Properties &prop = adaptor.getProperties();
1880   return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1881                                  inferredReturnShapes);
1882 }
1883 
1884 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1885     MLIRContext *context, ::std::optional<Location> location,
1886     DepthwiseConv2DOp::Adaptor adaptor,
1887     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1888   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1889 
1890   int64_t inputWidth = ShapedType::kDynamic;
1891   int64_t inputHeight = ShapedType::kDynamic;
1892   int64_t inputChannels = ShapedType::kDynamic;
1893 
1894   int64_t weightWidth = ShapedType::kDynamic;
1895   int64_t weightHeight = ShapedType::kDynamic;
1896   int64_t depthChannels = ShapedType::kDynamic;
1897 
1898   // Input shape describes the batch, input height/width and input channels.
1899   ShapeAdaptor inputShape(adaptor.getInput().getType());
1900   if (inputShape.hasRank()) {
1901     outputShape[0] = inputShape.getDimSize(0);
1902     inputHeight = inputShape.getDimSize(1);
1903     inputWidth = inputShape.getDimSize(2);
1904     inputChannels = inputShape.getDimSize(3);
1905   }
1906 
1907   // Weight shape describes filter height/width, input channels and multiplier.
1908   ShapeAdaptor weightShape(adaptor.getWeight().getType());
1909   if (weightShape.hasRank()) {
1910     weightHeight = weightShape.getDimSize(0);
1911     weightWidth = weightShape.getDimSize(1);
1912     inputChannels = ShapedType::isDynamic(inputChannels)
1913                         ? weightShape.getDimSize(2)
1914                         : inputChannels;
1915     depthChannels = weightShape.getDimSize(3);
1916   }
1917 
1918   // If both inputChannels and depthChannels are available we can determine
1919   // the output channels.
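  // For example, 8 input channels with a channel multiplier (weight dim 3) of
  // 3 give 8 * 3 = 24 output channels.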
1920   if (!ShapedType::isDynamic(inputChannels) &&
1921       !ShapedType::isDynamic(depthChannels)) {
1922     outputShape[3] = inputChannels * depthChannels;
1923   }
1924 
1925   // Bias shape can describe the output channels.
1926   ShapeAdaptor biasShape(adaptor.getBias().getType());
1927   if (biasShape.hasRank()) {
1928     outputShape[3] = ShapedType::isDynamic(outputShape[3])
1929                          ? biasShape.getDimSize(0)
1930                          : outputShape[3];
1931   }
1932 
1933   llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1934   llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1935   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1936 
1937   if (!ShapedType::isDynamic(inputHeight) &&
1938       !ShapedType::isDynamic(weightHeight)) {
1939     int64_t inputSize = inputHeight + padding[0] + padding[1];
1940     int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1941     int64_t unstridedResult = inputSize - filterSize + 1;
1942     outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1943   }
1944 
1945   if (!ShapedType::isDynamic(inputWidth) &&
1946       !ShapedType::isDynamic(weightWidth)) {
1947     int64_t inputSize = inputWidth + padding[2] + padding[3];
1948     int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1949     int64_t unstridedResult = inputSize - filterSize + 1;
1950     outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1951   }
1952 
1953   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1954   return success();
1955 }
1956 
1957 LogicalResult DepthwiseConv2DOp::verify() {
1958   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1959     return failure();
1960   return success();
1961 }
1962 
1963 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1964     MLIRContext *context, ::std::optional<Location> location,
1965     TransposeConv2DOp::Adaptor adaptor,
1966     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1967   // Start from the out_shape attribute; dynamic entries may be refined below.
1968   llvm::SmallVector<int64_t> outputShape =
1969       convertToMlirShape(adaptor.getOutShape());
1970 
1971   int64_t inputWidth = ShapedType::kDynamic;
1972   int64_t inputHeight = ShapedType::kDynamic;
1973   int64_t weightWidth = ShapedType::kDynamic;
1974   int64_t weightHeight = ShapedType::kDynamic;
1975 
1976   // Input shape describes input width/height and batch.
1977   ShapeAdaptor inputShape(adaptor.getInput().getType());
1978   if (inputShape.hasRank()) {
1979     outputShape[0] = ShapedType::isDynamic(outputShape[0])
1980                          ? inputShape.getDimSize(0)
1981                          : outputShape[0];
1982     inputHeight = inputShape.getDimSize(1);
1983     inputWidth = inputShape.getDimSize(2);
1984   }
1985 
1986   // Weight shape describes the filter height/width and the output channels.
1987   ShapeAdaptor weightShape(adaptor.getFilter().getType());
1988   if (weightShape.hasRank()) {
1989     outputShape[3] = ShapedType::isDynamic(outputShape[3])
1990                          ? weightShape.getDimSize(0)
1991                          : outputShape[3];
1992     weightHeight = weightShape.getDimSize(1);
1993     weightWidth = weightShape.getDimSize(2);
1994   }
1995 
1996   // Bias shape can describe the output channels.
1997   ShapeAdaptor biasShape(adaptor.getBias().getType());
1998   if (biasShape.hasRank()) {
1999     outputShape[3] = ShapedType::isDynamic(outputShape[3])
2000                          ? biasShape.getDimSize(0)
2001                          : outputShape[3];
2002   }
2003 
2004   llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
2005   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
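  // Transposed convolution grows the spatial dims. For example, an input
  // height of 8 with stride 2, zero out_pad and a 3-high filter gives
  // (8 - 1) * 2 + 0 + 0 + 3 = 17 output rows.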
2006 
2007   if (!ShapedType::isDynamic(inputHeight) &&
2008       !ShapedType::isDynamic(weightHeight)) {
2009     int64_t calculateSize =
2010         (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
2011     outputShape[1] =
2012         ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
2013   }
2014 
2015   if (!ShapedType::isDynamic(inputWidth) &&
2016       !ShapedType::isDynamic(weightWidth)) {
2017     int64_t calculateSize =
2018         (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
2019     outputShape[2] =
2020         ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
2021   }
2022 
2023   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2024   return success();
2025 }
2026 
2027 LogicalResult TransposeConv2DOp::verify() {
2028   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2029     return failure();
2030   return success();
2031 }
2032 
2033 LogicalResult IfOp::inferReturnTypeComponents(
2034     MLIRContext *context, ::std::optional<Location> location,
2035     IfOp::Adaptor adaptor,
2036     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2037   llvm::SmallVector<tosa::YieldOp> yieldOps;
2038   for (Region *region : adaptor.getRegions()) {
2039     for (auto &block : *region)
2040       if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
2041         yieldOps.push_back(returnOp);
2042   }
2043 
2044   if (yieldOps.empty())
2045     return failure();
2046 
2047   // Get the initial type information for the yield op.
2048   llvm::SmallVector<ValueKnowledge> resultKnowledge;
2049   resultKnowledge.reserve(yieldOps.front().getNumOperands());
2050   for (auto operand : yieldOps.front().getOperands()) {
2051     resultKnowledge.push_back(
2052         ValueKnowledge::getKnowledgeFromType(operand.getType()));
2053   }
2054 
2055   for (auto yieldOp : yieldOps) {
2056     if (resultKnowledge.size() != yieldOp.getNumOperands())
2057       return failure();
2058 
2059     for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
2060       int32_t index = it.index();
2061       auto meet = ValueKnowledge::meet(
2062           resultKnowledge[index],
2063           ValueKnowledge::getKnowledgeFromType(it.value().getType()));
2064       if (!meet)
2065         continue;
2066       resultKnowledge[index] = meet;
2067     }
2068   }
2069 
2070   for (const ValueKnowledge &result : resultKnowledge) {
2071     inferredReturnShapes.push_back(result.getShapedTypeComponents());
2072   }
2073 
2074   return success();
2075 }
2076 
2077 LogicalResult WhileOp::inferReturnTypeComponents(
2078     MLIRContext *context, ::std::optional<Location> location,
2079     WhileOp::Adaptor adaptor,
2080     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2081   llvm::SmallVector<tosa::YieldOp> yieldOps;
2082   for (auto &block : adaptor.getBody())
2083     if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
2084       yieldOps.push_back(returnOp);
2085 
2086   // TOSA's while must have a tosa.yield as its terminator. If none is found,
2087   // this tosa.while is invalid.
2088   if (yieldOps.empty())
2089     return failure();
2090 
2091   // Get the initial type information from the first yield's operand types.
2092   llvm::SmallVector<ValueKnowledge> resultKnowledge;
2093   resultKnowledge.reserve(yieldOps.front().getNumOperands());
2094   for (auto operand : yieldOps.front().getOperands()) {
2095     resultKnowledge.push_back(
2096         ValueKnowledge::getKnowledgeFromType(operand.getType()));
2097   }
2098 
2099   for (auto yieldOp : yieldOps) {
2100     if (resultKnowledge.size() != yieldOp.getNumOperands())
2101       return failure();
2102 
2103     for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
2104       int32_t index = it.index();
2105       if (auto meet = ValueKnowledge::meet(
2106               resultKnowledge[index],
2107               ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
2108         resultKnowledge[index] = meet;
2109       }
2110     }
2111   }
2112 
2113   for (const ValueKnowledge &result : resultKnowledge) {
2114     inferredReturnShapes.push_back(result.getShapedTypeComponents());
2115   }
2116 
2117   return success();
2118 }
2119 
2120 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
2121   if (auto vt = llvm::dyn_cast<VectorType>(getType()))
2122     return llvm::to_vector<4>(vt.getShape());
2123   return std::nullopt;
2124 }
2125 
2126 // Parse and print of IfOp follow the implementation of the SCF dialect.
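// A rough sketch of the custom form this parser accepts (the result types and
// the 'else' region are both optional):
//   tosa.cond_if %cond -> (tensor<f32>) {
//     ...
//   } else {
//     ...
//   }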
2127 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2128   // Create the 'then' and 'else' regions.
2129   result.regions.reserve(2);
2130   Region *thenRegion = result.addRegion();
2131   Region *elseRegion = result.addRegion();
2132 
2133   auto &builder = parser.getBuilder();
2134   OpAsmParser::UnresolvedOperand cond;
2135   // Create an i1 tensor type for the boolean condition.
2136   Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
2137   if (parser.parseOperand(cond) ||
2138       parser.resolveOperand(cond, i1Type, result.operands))
2139     return failure();
2140   // Parse the optional result type list.
2141   if (parser.parseOptionalArrowTypeList(result.types))
2142     return failure();
2143   // Parse the 'then' region.
2144   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2145     return failure();
2146 
2147   // If we find an 'else' keyword then parse the 'else' region.
2148   if (!parser.parseOptionalKeyword("else")) {
2149     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2150       return failure();
2151   }
2152 
2153   // Parse the optional attribute list.
2154   if (parser.parseOptionalAttrDict(result.attributes))
2155     return failure();
2156   return success();
2157 }
2158 
2159 void IfOp::print(OpAsmPrinter &p) {
2160   bool printBlockTerminators = false;
2161 
2162   p << " " << getCond();
2163   if (!getResults().empty()) {
2164     p << " -> (" << getResultTypes() << ")";
2165     // Print yield explicitly if the op defines values.
2166     printBlockTerminators = true;
2167   }
2168   p << ' ';
2169   p.printRegion(getThenBranch(),
2170                 /*printEntryBlockArgs=*/false,
2171                 /*printBlockTerminators=*/printBlockTerminators);
2172 
2173   // Print the 'else' region if it exists and has a block.
2174   auto &elseRegion = getElseBranch();
2175   if (!elseRegion.empty()) {
2176     p << " else ";
2177     p.printRegion(elseRegion,
2178                   /*printEntryBlockArgs=*/false,
2179                   /*printBlockTerminators=*/printBlockTerminators);
2180   }
2181 
2182   p.printOptionalAttrDict((*this)->getAttrs());
2183 }
2184 
2185 LogicalResult ReverseOp::verify() {
2186   TensorType inputType = getInput1().getType();
2187   TensorType outputType = getOutput().getType();
2188   int32_t reverseAxis = getAxis();
2189 
2190   if (reverseAxis < 0)
2191     return emitOpError("expected non-negative reverse axis");
2192   if (inputType.hasRank()) {
2193     int64_t inputRank = inputType.getRank();
2194     // We allow for a special case where the input/output shape has rank 0 and
2195     // axis is also 0.
2196     if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
2197       return emitOpError("expect input tensor rank (")
2198              << inputRank << ") to be larger than reverse axis (" << reverseAxis
2199              << ")";
2200   }
2201   if (outputType.hasRank()) {
2202     int64_t outputRank = outputType.getRank();
2203     if (inputType.hasRank() && outputRank != inputType.getRank())
2204       return emitOpError(
2205           "expect output tensor rank to be equal to input tensor rank");
2206     if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
2207       return emitOpError("expect output tensor rank (")
2208              << outputRank << ") to be larger than reverse axis ("
2209              << reverseAxis << ")";
2210   }
2211   return success();
2212 }
2213 
2214 // Parse and print of WhileOp follow the implementation of the SCF dialect.
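// A rough sketch of the custom form this parser accepts:
//   tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ... cond region ...
//   } do {
//     ... body region ...
//   }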
2215 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2216   SmallVector<OpAsmParser::Argument, 4> regionArgs;
2217   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2218   Region *cond = result.addRegion();
2219   Region *body = result.addRegion();
2220 
2221   OptionalParseResult listResult =
2222       parser.parseOptionalAssignmentList(regionArgs, operands);
2223   if (listResult.has_value() && failed(listResult.value()))
2224     return failure();
2225 
2226   FunctionType functionType;
2227   SMLoc typeLoc = parser.getCurrentLocation();
2228   if (failed(parser.parseColonType(functionType)))
2229     return failure();
2230 
2231   result.addTypes(functionType.getResults());
2232 
2233   if (functionType.getNumInputs() != operands.size()) {
2234     return parser.emitError(typeLoc)
2235            << "expected as many input types as operands "
2236            << "(expected " << operands.size() << " got "
2237            << functionType.getNumInputs() << ")";
2238   }
2239 
2240   // Resolve input operands.
2241   if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2242                                     parser.getCurrentLocation(),
2243                                     result.operands)))
2244     return failure();
2245 
2246   // Propagate the types into the region arguments.
2247   for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2248     regionArgs[i].type = functionType.getInput(i);
2249 
2250   return failure(parser.parseRegion(*cond, regionArgs) ||
2251                  parser.parseKeyword("do") || parser.parseRegion(*body) ||
2252                  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2253 }
2254 
2255 static void printInitializationList(OpAsmPrinter &parser,
2256                                     Block::BlockArgListType blocksArgs,
2257                                     ValueRange initializers,
2258                                     StringRef prefix = "") {
2259   assert(blocksArgs.size() == initializers.size() &&
2260          "expected same length of arguments and initializers");
2261   if (initializers.empty())
2262     return;
2263 
2264   parser << prefix << '(';
2265   llvm::interleaveComma(
2266       llvm::zip(blocksArgs, initializers), parser,
2267       [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
2268   parser << ")";
2269 }
2270 
2271 void WhileOp::print(OpAsmPrinter &parser) {
2272   printInitializationList(parser, getCond().front().getArguments(), getInputs(),
2273                           " ");
2274   parser << " : ";
2275   parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
2276   parser << ' ';
2277   parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
2278   parser << " do ";
2279   parser.printRegion(getBody());
2280   parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2281 }
2282 
2283 //===----------------------------------------------------------------------===//
2284 // TOSA Shape and Shape Operators Helper functions.
2285 //===----------------------------------------------------------------------===//
2286 
2287 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
2288   return mlir::isa<tosa::shapeType>(t);
2289 }
2290 
2291 LogicalResult
2292 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
2293                               int rank) {
2294   if (rank < 0)
2295     return emitError() << "invalid rank (must be >= 0): " << rank;
2296   return success();
2297 }
2298 
2299 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
2300   for (auto v : op->getOperands()) {
2301     if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
2302       Operation *definingOp = v.getDefiningOp();
2303       if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
2304         return op->emitOpError("shape operand is not compile time resolvable");
2305       }
2306     }
2307   }
2308   return success();
2309 }
2310 
2311 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
2312   for (auto type : op->getOperandTypes()) {
2313     if (!mlir::isa<mlir::tosa::shapeType>(type)) {
2314       return op->emitOpError("must have operands with tosa shape type");
2315     }
2316   }
2317   for (auto type : op->getResultTypes()) {
2318     if (!mlir::isa<mlir::tosa::shapeType>(type)) {
2319       return op->emitOpError("must have result with tosa shape type");
2320     }
2321   }
2322   return success();
2323 }
2324 
2325 LogicalResult
2326 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
2327   if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
2328       failed(verifyTosaShapeOperator(op)))
2329     return failure();
2330 
2331   // Local helper that returns the rank of a tosa shape type.
2332   auto getRank = [](const Type type) {
2333     return mlir::cast<mlir::tosa::shapeType>(type).getRank();
2334   };
2335   auto operandTypes = op->getOperandTypes();
2336   auto resultTypes = op->getResultTypes();
2337 
2338   auto rank = getRank(*op->getOperandTypes().begin());
2339   for (auto type : operandTypes) {
2340     if (getRank(type) != rank) {
2341       return op->emitOpError("operands don't have matching ranks");
2342     }
2343   }
2344   for (auto type : resultTypes) {
2345     if (getRank(type) != rank) {
2346       return op->emitOpError("result shape has different rank than operands");
2347     }
2348   }
2349   return success();
2350 }
2351 
2352 //===----------------------------------------------------------------------===//
2353 // TOSA Shape Operators verify functions.
2354 //===----------------------------------------------------------------------===//
2355 
2356 LogicalResult tosa::ConstShapeOp::verify() {
2357   // Check that the number of elements in the value attr equals the result rank.
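  // For example, a value attribute holding three elements requires a result of
  // type !tosa.shape<3>; a single element with a rank-0 result is also
  // accepted.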
2358   auto count = getValue().getNumElements();
2359   auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
2360   if (!(count == rank || (count == 1 && rank == 0))) {
2361     return emitOpError("expect number of elements in attribute value (")
2362            << count << ") to be equal to the rank (" << rank
2363            << ") for the result shape type";
2364   }
2365   return success();
2366 }
2367 
2368 //===----------------------------------------------------------------------===//
2369 // TOSA Attribute Definitions.
2370 //===----------------------------------------------------------------------===//
2371 
2372 #define GET_ATTRDEF_CLASSES
2373 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
2374 
2375 //===----------------------------------------------------------------------===//
2376 // TOSA Type Definitions.
2377 //===----------------------------------------------------------------------===//
2378 #define GET_TYPEDEF_CLASSES
2379 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
2380 
2381 //===----------------------------------------------------------------------===//
2382 // TOSA Operator Definitions.
2383 //===----------------------------------------------------------------------===//
2384 
2385 #define GET_OP_CLASSES
2386 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
2387