//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

namespace {
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};

/// This class implements the bytecode interface for the Tosa dialect.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr, unlike the standard
  // dialect constant, which supports all attributes.
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
                                        Attribute &attr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(attr))) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    }
    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
      typeAttr = TypeAttr::get(typedAttr.getType());
    }
    return success();
  }

  Type type;
  if (failed(parser.parseColonType(type))) {
    return parser.emitError(parser.getCurrentLocation()) << "expected type";
  }
  typeAttr = TypeAttr::get(type);

  return success();
}

void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                                 Attribute attr) {
  bool needsSpace = false;
  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
  if (!typedAttr || typedAttr.getType() != type.getValue()) {
    p << ": ";
    p.printAttribute(type);
    needsSpace = true; // subsequent attr value needs a space separator
  }
  if (attr) {
    if (needsSpace)
      p << ' ';
    p << "= ";
    p.printAttribute(attr);
  }
}
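
// For illustration, the parse/print helpers above round-trip either an
// explicit initial value or a bare type. A sketch of the two accepted forms
// (the exact assembly format is defined by the ops that use these helpers):
//
//   = dense<0> : tensor<1xi8>   "=" form: attribute given; the TypeAttr is
//                               derived from the TypedAttr when present.
//   : tensor<1xf32>             ":" form: only the type is given.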

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());

  RankedTensorType weightType;
  if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
    weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
  else
    weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  // Must be ranked tensor types
  if (!inputType) {
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
    return failure();
  }
  if (!weightType) {
    if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
      op.emitOpError("expect a ranked tensor for filter, got ")
          << op.getFilter();
    } else {
      op.emitOpError("expect a ranked tensor for weight, got ")
          << op.getWeight();
    }
    return failure();
  }

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // A quantized type must have constructed the quantization attribute, and an
  // unquantized type must not have one.
  if ((inputIsQuant && !op.getQuantizationInfo()) ||
      (!inputIsQuant && op.getQuantizationInfo())) {
    op.emitOpError("quantizationattr is required for quantized type, and not "
                   "allowed for float type");
    return failure();
  }
  return success();
}
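
// As an illustrative sketch of the pairing rule above (the example attribute
// payload is hypothetical):
//
//   - f32 input and weight with no quantization_info: accepted.
//   - i8 input and weight carrying quantization_info (e.g. a
//     #tosa.conv_quant<input_zp = 0, weight_zp = 0> attribute): accepted.
//   - f32 input with i8 weight, or a quantized conv that is missing
//     quantization_info: rejected by this verifier.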

LogicalResult tosa::ConstOp::verify() {

  auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }

  return success();
}

template <typename T>
static LogicalResult verifyConvOpModes(T op) {
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  return success();
}
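
// Quick reference for the acc_type constraints checked above (derived
// directly from the checks; see the TOSA specification for the normative
// table):
//
//   input element type | allowed acc_type
//   -------------------+-----------------
//   i8                 | i32
//   i16                | i48
//   f8 (E5M2 / E4M3)   | f16
//   f16                | f16 or f32
//   bf16               | f32
//   f32                | f32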

LogicalResult tosa::ArgMaxOp::verify() {
  // Ensure the output is an integer (or index) type.
  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
  if (!resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  // Ensure the axis is within the tensor rank.
  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  const int64_t axis = getAxisAttr().getInt();
  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  return success();
}

LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputType = llvm::cast<ShapedType>(getInput().getType());

  auto inputETy = inputType.getElementType();
  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
    inputETy = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
    resultETy = quantType.getStorageType();

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if ((inputETy.isF32() && resultETy.isF32()) ||
      (inputETy.isF16() && resultETy.isF16()) ||
      (inputETy.isBF16() && resultETy.isBF16()) ||
      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
      (inputETy.isInteger(16) && resultETy.isInteger(16)))
    return success();

  return emitOpError("input/output element types are incompatible.");
}

LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type maxFpType = getMaxFpAttr().getType();
  mlir::Type minFpType = getMinFpAttr().getType();
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  // If the input datatype is float, check that the two min/max_fp attributes
  // share the same type and that their type is either the same as the input's
  // datatype, or a float type whose bitwidth > input datatype bitwidth.
  if (!inputETy.isInteger(dataTypeBitWidth)) {
    if (((maxFpType != minFpType) ||
         (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
                                       inputETy.getIntOrFloatBitWidth())))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");
  }

  return success();
}
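
// For illustration (a sketch of the rule above): with an f16 input, min/max_fp
// may both be f16 or both be f32 (a wider float), but mixing f16 and f32
// bounds, or pairing f16 bounds with an f32 input, is rejected.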

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bitwidth of the input and weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, DenseI64ArrayAttr pad,
                                     DenseI64ArrayAttr stride,
                                     DenseI64ArrayAttr dilation,
                                     TypeAttr accType) {

  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("acc_type", accType);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d, which has outpad and output shape
/// attributes.
static void buildTransConvOpWithQuantInfo(
    OpBuilder &builder, OperationState &result, Type outputType, Value input,
    Value weight, Value bias, DenseI64ArrayAttr outpad,
    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("out_shape", outputShape);
  result.addAttribute("acc_type", accType);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {

  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also used where a fully_connected op would be
/// constructed but the weight is not a constant. In that case, the
/// fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
        inputType.getElementType());
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as
/// it has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input,
                              DenseArrayAttr kernel, DenseArrayAttr stride,
                              DenseArrayAttr pad, TypeAttr accType) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  result.addAttribute("acc_type", accType);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator, which needs to create its
/// own OptionalAttr quantization_info parameter to scale the padding values
/// correctly. The absence of pad_const is interpreted as zero padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator when an explicit pad_const
/// value is passed in. It also optionally constructs quantization_info.
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                 OperationState &result,
                                                 Type outputType, Value input,
                                                 Value paddings,
                                                 Value padConst) {
  result.addOperands({input, paddings, padConst});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      // TODO(jennik): Update function to have better case handling for
      // invalid operands and for ranked tensors.
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}
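
// A worked example of the broadcast resolution above (illustrative only):
// shapes [2, 1, 3] and [4, 3] resolve to [2, 4, 3]; size-1 dims stretch to
// match, missing leading dims are treated as 1, and any other mismatch
// (e.g. [2, 3] vs [4, 3]) makes the resolution fail.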

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
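
// For example (illustrative): argmax over axis 1 of a tensor<2x3x4xf32> input
// drops that dimension, giving an inferred result shape of [2, 4].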

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());

  if (!inputShape.hasRank())
    return failure();

  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  // Note that we can support this calculation symbolically
  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));

  return success();
}
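
// Concretely (illustrative): an RFFT2d input of tensor<8x16x32xf32> infers
// two results (real and imaginary parts), each of shape [8, 16, 17].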

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
  return success();
}

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // Copy the operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }
  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
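
// Worked example (illustrative): concatenating tensor<2x3xf32> and
// tensor<2x5xf32> along axis 1 infers [2, 8]; if any input's axis dimension
// is dynamic, the inferred axis dimension is dynamic as well.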

LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto elementType = IntegerType::get(context, /*width=*/1);

  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
    return success();
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
  return success();
}

bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != r.size() || l.size() != 1)
    return false;
  return succeeded(verifyCompatibleShape(l[0], r[0]));
}

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FullyConnectedOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  ShapeAdaptor biasShape(adaptor.getBias().getType());

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, ShapedType::kDynamic);

  if (inputShape.hasRank()) {
    outShape[0] = inputShape.getDimSize(0);
  }

  if (weightShape.hasRank()) {
    outShape[1] = weightShape.getDimSize(0);
  }

  if (biasShape.hasRank()) {
    outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
                                                      : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
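
// For example (illustrative): input tensor<4x16xf32>, weight tensor<8x16xf32>,
// and bias tensor<8xf32> infer an output shape of [4, 8]; the bias dimension
// is used as a fallback only when the output-channel dimension is still
// unknown after consulting the weight.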

LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor lhsShape(adaptor.getA().getType());
  ShapeAdaptor rhsShape(adaptor.getB().getType());

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
                                                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
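
// For example (illustrative): A = tensor<1x4x8xf32> and B = tensor<1x8x16xf32>
// infer [1, 4, 16]; the batch dimension is taken from whichever operand has a
// static one.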

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  auto paddingRank =
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
  SmallVector<int64_t> outputShape;

  // If the input rank is unknown, we can infer the output rank using the
  // padding shape's rank divided by 2.
  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  // If the padding value is not a constant, all dimensions must be dynamic.
  if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
                                paddingValues)) {
    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      // If either padding for dim i is negative, the output dim is unknown.
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
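
// Worked example (illustrative): input tensor<2x3xf32> with constant padding
// values [1, 1, 2, 2] (front/back per dimension) infers an output shape of
// [2+1+1, 3+2+2] = [4, 7].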
8675a4e7760SRob Suderman 
86837263b6cSLongsheng Mou LogicalResult tosa::PadOp::verify() {
86937263b6cSLongsheng Mou   RankedTensorType inputType = getInput1().getType();
87037263b6cSLongsheng Mou   RankedTensorType outputType = getOutput().getType();
8717e622b61SJerry-Ge   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
87237263b6cSLongsheng Mou 
87337263b6cSLongsheng Mou   if (inputType.getRank() != outputType.getRank())
87437263b6cSLongsheng Mou     return emitOpError() << "expect same input and output tensor rank.";
87537263b6cSLongsheng Mou 
8767e622b61SJerry-Ge   if (paddingRank != inputType.getRank() * 2)
877c1d01b2fSLongsheng Mou     return emitOpError() << "expected padding tensor dim 0 to have size "
878c1d01b2fSLongsheng Mou                          << inputType.getRank() * 2
8797e622b61SJerry-Ge                          << " (2*rank(shape1)) but got size " << paddingRank;
88037263b6cSLongsheng Mou 
88137263b6cSLongsheng Mou   return success();
88237263b6cSLongsheng Mou }
88337263b6cSLongsheng Mou 
884fa5a607dSRob Suderman static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
885fa5a607dSRob Suderman   return to_vector(llvm::map_range(shape, [](int64_t dim) {
886399638f9SAliia Khasanova     return dim == -1 ? ShapedType::kDynamic : dim;
887fa5a607dSRob Suderman   }));
888fa5a607dSRob Suderman }
889fa5a607dSRob Suderman 
8905a4e7760SRob Suderman LogicalResult tosa::SliceOp::inferReturnTypeComponents(
89122426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
892057fc8e7SAmanda Tang     SliceOp::Adaptor adaptor,
8935a4e7760SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
894*956c0707SJerry-Ge 
895*956c0707SJerry-Ge   Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
896*956c0707SJerry-Ge   SmallVector<int64_t> start;
897*956c0707SJerry-Ge   SmallVector<int64_t> size;
898*956c0707SJerry-Ge 
899*956c0707SJerry-Ge   if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
900*956c0707SJerry-Ge       !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
901*956c0707SJerry-Ge     auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
902*956c0707SJerry-Ge     SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
903*956c0707SJerry-Ge     inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
904*956c0707SJerry-Ge     return success();
905*956c0707SJerry-Ge   }
9067986e0caSTai Ly 
9077986e0caSTai Ly   // if size[i] is -1, all remaining elements in dimension i are included
9087986e0caSTai Ly   // in the slice, similar to TF.
9097986e0caSTai Ly   ShapeAdaptor inputShape(adaptor.getInput1().getType());
9107986e0caSTai Ly   // initialize outputShape to all unknown
9117986e0caSTai Ly   SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
9127986e0caSTai Ly   if (inputShape.hasRank()) {
9137986e0caSTai Ly     for (size_t i = 0; i < size.size(); i++) {
9147986e0caSTai Ly       if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
9157986e0caSTai Ly           (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
9167986e0caSTai Ly            start[i] < inputShape.getDimSize(i))) {
9177986e0caSTai Ly         // size[i] is not 0 and not < -1, and start[i] is in valid range
9187986e0caSTai Ly         if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
9197986e0caSTai Ly           // input shape has unknown dim[i] - only valid if size[i] > 0
9207986e0caSTai Ly           if (size[i] > 0) {
9217986e0caSTai Ly             outputShape[i] = size[i];
9227986e0caSTai Ly           }
9237986e0caSTai Ly         } else {
9247986e0caSTai Ly           // input shape has known dim[i]
9257986e0caSTai Ly           if (size[i] == -1) {
9267986e0caSTai Ly             outputShape[i] = inputShape.getDimSize(i) - start[i];
9277986e0caSTai Ly           } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
9287986e0caSTai Ly             // start[i] + size[i] is within bound of input shape's dim[i]
9297986e0caSTai Ly             outputShape[i] = size[i];
9307986e0caSTai Ly           }
9317986e0caSTai Ly         }
9327986e0caSTai Ly       }
9337986e0caSTai Ly     }
9347986e0caSTai Ly   } else {
9357986e0caSTai Ly     outputShape = convertToMlirShape(size);
9367986e0caSTai Ly   }
9377986e0caSTai Ly   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
9385a4e7760SRob Suderman   return success();
9395a4e7760SRob Suderman }
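
// For example (illustrative values): an input of shape [4, 8] with
// start = [1, 2] and size = [2, -1] infers dim 0 as the requested size 2
// and dim 1 as the remaining 8 - 2 = 6 elements, i.e. an output of [2, 6].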
9405a4e7760SRob Suderman 
941ab7e8b76Slong.chen LogicalResult tosa::SliceOp::verify() {
942c6876b4eSJerry-Ge   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
943ab7e8b76Slong.chen   if (!inputType)
944ab7e8b76Slong.chen     return success();
945ab7e8b76Slong.chen 
946*956c0707SJerry-Ge   auto startShapeRank =
947*956c0707SJerry-Ge       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
948*956c0707SJerry-Ge   if (inputType.getRank() != startShapeRank)
949ab7e8b76Slong.chen     return emitOpError(
950ab7e8b76Slong.chen         "length of start is not equal to the rank of the input shape");
951ab7e8b76Slong.chen 
952*956c0707SJerry-Ge   auto sizeShapeRank =
953*956c0707SJerry-Ge       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
954*956c0707SJerry-Ge   if (inputType.getRank() != sizeShapeRank)
955ab7e8b76Slong.chen     return emitOpError(
956ab7e8b76Slong.chen         "length of size is not equal to the rank of the input shape");
957ab7e8b76Slong.chen 
958ab7e8b76Slong.chen   return success();
959ab7e8b76Slong.chen }
960ab7e8b76Slong.chen 
961519eef3bSLongsheng Mou LogicalResult tosa::MulOp::verify() {
962a58e774fSJack Frankland   auto resElemType = getElementTypeOrSelf(getOutput());
963a58e774fSJack Frankland 
964a58e774fSJack Frankland   // Verify that the element types of the operands and the result match the
965a58e774fSJack Frankland   // TOSA specification.
966a58e774fSJack Frankland   if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
967a58e774fSJack Frankland     IntegerType lhsIntType =
968a58e774fSJack Frankland         cast<IntegerType>(getElementTypeOrSelf(getInput1()));
969a58e774fSJack Frankland     IntegerType rhsIntType =
970a58e774fSJack Frankland         cast<IntegerType>(getElementTypeOrSelf(getInput2()));
971a58e774fSJack Frankland     if (lhsIntType != rhsIntType)
972a58e774fSJack Frankland       return emitOpError("requires the same element type for all operands");
973a58e774fSJack Frankland 
974a58e774fSJack Frankland     // Though the spec requires the result element type to be i32, the
975a58e774fSJack Frankland     // dialect relaxes this constraint to ease interoperation with other
976a58e774fSJack Frankland     // dialects.
977a58e774fSJack Frankland     if (lhsIntType.getWidth() > resIntType.getWidth())
978a58e774fSJack Frankland       return emitOpError("invalid data type size for operands or result");
979a58e774fSJack Frankland 
980a58e774fSJack Frankland   } else {
981a58e774fSJack Frankland     // For other supported types, the spec requires the same element type
982a58e774fSJack Frankland     // for all operands (excluding the `shift` operand) and results.
983a58e774fSJack Frankland     for (int i = 0; i < 2; ++i) {
984a58e774fSJack Frankland       if (getElementTypeOrSelf(getOperand(i)) != resElemType)
985a58e774fSJack Frankland         return emitOpError(
986a58e774fSJack Frankland             "requires the same element type for all operands and results");
987a58e774fSJack Frankland     }
988a58e774fSJack Frankland   }
989a58e774fSJack Frankland 
990a58e774fSJack Frankland   // Verify that the op has the same rank for all main operands and result
991a58e774fSJack Frankland   // types, if known. This excludes extra operands such as the shift of mul
992a58e774fSJack Frankland   // op, which is the only difference from the built-in
992a58e774fSJack Frankland   // `SameOperandsAndResultRank` trait.
993a58e774fSJack Frankland 
994a58e774fSJack Frankland   // Delegate function that returns true if type is a shaped type with known
995a58e774fSJack Frankland   // rank.
996a58e774fSJack Frankland   auto hasRank = [](const Type type) {
997a58e774fSJack Frankland     if (auto shapedType = dyn_cast<ShapedType>(type))
998a58e774fSJack Frankland       return shapedType.hasRank();
999a58e774fSJack Frankland 
1000a58e774fSJack Frankland     return false;
1001a58e774fSJack Frankland   };
1002a58e774fSJack Frankland 
1003a58e774fSJack Frankland   auto rankedOperandTypes =
1004a58e774fSJack Frankland       llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
1005a58e774fSJack Frankland 
1006a58e774fSJack Frankland   auto rankedResultTypes =
1007a58e774fSJack Frankland       llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1008a58e774fSJack Frankland 
1009a58e774fSJack Frankland   // If all operands and results are unranked, then no further verification.
1010a58e774fSJack Frankland   if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1011a58e774fSJack Frankland     return success();
1012a58e774fSJack Frankland 
1013a58e774fSJack Frankland   // Delegate function that returns the rank of a shaped type with known rank.
1014a58e774fSJack Frankland   auto getRank = [](const Type type) {
1015a58e774fSJack Frankland     return cast<ShapedType>(type).getRank();
1016a58e774fSJack Frankland   };
1017a58e774fSJack Frankland 
1018a58e774fSJack Frankland   auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1019a58e774fSJack Frankland                                           : getRank(*rankedResultTypes.begin());
1020a58e774fSJack Frankland 
1021a58e774fSJack Frankland   // Bound the loop by the number of ranked operand types to avoid
1021a58e774fSJack Frankland   // out-of-bounds access when fewer than two operands are ranked.
1021a58e774fSJack Frankland   for (size_t i = 0; i < 2 && i < rankedOperandTypes.size(); ++i) {
1022a58e774fSJack Frankland     if (rank != getRank(rankedOperandTypes[i])) {
1023a58e774fSJack Frankland       return emitOpError("operands don't have matching ranks");
1024a58e774fSJack Frankland     }
1025a58e774fSJack Frankland   }
1026a58e774fSJack Frankland 
1027a58e774fSJack Frankland   for (const auto type : rankedResultTypes) {
1028a58e774fSJack Frankland     if (rank != getRank(type)) {
1029a58e774fSJack Frankland       return emitOpError("result type has different rank than operands");
1030a58e774fSJack Frankland     }
1031a58e774fSJack Frankland   }
1032519eef3bSLongsheng Mou 
1033519eef3bSLongsheng Mou   return success();
1034519eef3bSLongsheng Mou }
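
// For example (illustrative cases): i8 x i8 -> i32 is accepted since the
// operand width does not exceed the result width, while i32 x i32 -> i16 is
// rejected; for floating point, f32 x f32 must produce f32.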
1035519eef3bSLongsheng Mou 
10365a4e7760SRob Suderman LogicalResult tosa::TableOp::inferReturnTypeComponents(
103722426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1038057fc8e7SAmanda Tang     TableOp::Adaptor adaptor,
10395a4e7760SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1040c6876b4eSJerry-Ge   ShapeAdaptor inputShape(adaptor.getInput1().getType());
10415a4e7760SRob Suderman 
104209349303SJacques Pienaar   if (!inputShape.hasRank()) {
10435a4e7760SRob Suderman     inferredReturnShapes.push_back(ShapedTypeComponents());
10445a4e7760SRob Suderman     return success();
10455a4e7760SRob Suderman   }
10465a4e7760SRob Suderman 
104709349303SJacques Pienaar   inferredReturnShapes.resize(1);
104809349303SJacques Pienaar   inputShape.getDims(inferredReturnShapes[0]);
10495a4e7760SRob Suderman   return success();
10505a4e7760SRob Suderman }
10515a4e7760SRob Suderman 
10521e347062SLongsheng Mou LogicalResult tosa::TableOp::verify() {
1053c6876b4eSJerry-Ge   TensorType inputType = getInput1().getType();
10541e347062SLongsheng Mou   TensorType outputType = getOutput().getType();
10551e347062SLongsheng Mou 
10561e347062SLongsheng Mou   if (inputType.hasRank() && outputType.hasRank() &&
10571e347062SLongsheng Mou       inputType.getRank() != outputType.getRank())
10581e347062SLongsheng Mou     return emitOpError()
10591e347062SLongsheng Mou            << "expected input tensor rank to equal result tensor rank";
10601e347062SLongsheng Mou 
10611e347062SLongsheng Mou   auto inputDims = inputType.getShape();
10621e347062SLongsheng Mou   auto outputDims = outputType.getShape();
10631e347062SLongsheng Mou   for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
10641e347062SLongsheng Mou     int64_t dim = it.index();
10651e347062SLongsheng Mou     auto [inputDim, outputDim] = it.value();
10661e347062SLongsheng Mou     if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
10671e347062SLongsheng Mou       return emitOpError() << "dim(result, " << dim << ") = " << outputDim
10681e347062SLongsheng Mou                            << " doesn't match dim(input, " << dim
10691e347062SLongsheng Mou                            << ") = " << inputDim;
10701e347062SLongsheng Mou     }
10711e347062SLongsheng Mou   }
10721e347062SLongsheng Mou   return success();
10731e347062SLongsheng Mou }
10741e347062SLongsheng Mou 
1075f09db6a3SJerry-Ge LogicalResult
1076f09db6a3SJerry-Ge tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1077f09db6a3SJerry-Ge   // Multiples must be constants.
1078f09db6a3SJerry-Ge   DenseIntElementsAttr multiplesAttr;
1079f09db6a3SJerry-Ge   if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1080f09db6a3SJerry-Ge     return failure();
1081f09db6a3SJerry-Ge   multiples = llvm::to_vector(
1082f09db6a3SJerry-Ge       llvm::map_range(multiplesAttr.getValues<APInt>(),
1083f09db6a3SJerry-Ge                       [](const APInt &val) { return val.getSExtValue(); }));
1084f09db6a3SJerry-Ge   return success();
1085f09db6a3SJerry-Ge }
1086f09db6a3SJerry-Ge 
10875a4e7760SRob Suderman LogicalResult tosa::TileOp::inferReturnTypeComponents(
108822426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1089057fc8e7SAmanda Tang     TileOp::Adaptor adaptor,
10905a4e7760SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1091f09db6a3SJerry-Ge   DenseIntElementsAttr multiplesAttr;
1092f09db6a3SJerry-Ge   if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
1093f09db6a3SJerry-Ge     return failure();
1094f09db6a3SJerry-Ge 
1095f09db6a3SJerry-Ge   SmallVector<int64_t> multiples = llvm::to_vector(
1096f09db6a3SJerry-Ge       llvm::map_range(multiplesAttr.getValues<APInt>(),
1097f09db6a3SJerry-Ge                       [](const APInt &val) { return val.getSExtValue(); }));
1098f09db6a3SJerry-Ge 
1099057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput1().getType());
11005a4e7760SRob Suderman   SmallVector<int64_t> outputShape;
110109349303SJacques Pienaar   if (!inputShape.hasRank()) {
1102399638f9SAliia Khasanova     outputShape.resize(multiples.size(), ShapedType::kDynamic);
11035a4e7760SRob Suderman     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
11045a4e7760SRob Suderman     return success();
11050e6c679cSKazu Hirata   } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
1106b6d67af2SFelix Schneider     return failure();
11075a4e7760SRob Suderman 
11085a4e7760SRob Suderman   // Any non-dynamic dimension is multiplied by its multiple to a known size.
11095a4e7760SRob Suderman   outputShape.reserve(multiples.size());
111009349303SJacques Pienaar   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1111fb4cedccSAliia Khasanova     int64_t dim = inputShape.getDimSize(i);
1112399638f9SAliia Khasanova     if (dim != ShapedType::kDynamic)
11139e1a3441SAlexander Shaposhnikov       dim *= multiples[i];
11145a4e7760SRob Suderman     outputShape.push_back(dim);
11155a4e7760SRob Suderman   }
11165a4e7760SRob Suderman 
11175a4e7760SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
11185a4e7760SRob Suderman   return success();
11195a4e7760SRob Suderman }
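
// For example (illustrative values): an input of shape [2, ?] with
// multiples = [3, 2] infers an output shape of [6, ?]; dynamic dimensions
// remain dynamic.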
11205a4e7760SRob Suderman 
1121b6d67af2SFelix Schneider LogicalResult tosa::TileOp::verify() {
1122b6d67af2SFelix Schneider   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
1123b6d67af2SFelix Schneider   ShapedType outputType = llvm::cast<ShapedType>(getType());
1124f09db6a3SJerry-Ge 
1125f09db6a3SJerry-Ge   shapeType multiplesType =
1126f09db6a3SJerry-Ge       llvm::cast<tosa::shapeType>(getMultiples().getType());
1127f09db6a3SJerry-Ge 
1128f09db6a3SJerry-Ge   auto multiplesRank = multiplesType.getRank();
1129b6d67af2SFelix Schneider 
1130b6d67af2SFelix Schneider   if (inputType.hasRank()) {
1131f09db6a3SJerry-Ge     if (inputType.getRank() != multiplesRank)
1132f09db6a3SJerry-Ge       return emitOpError("expect 'multiples' to have rank ")
1133f09db6a3SJerry-Ge              << inputType.getRank() << " but got " << multiplesRank << ".";
1134b6d67af2SFelix Schneider     if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
1135b6d67af2SFelix Schneider       return emitOpError("expect same input and output tensor rank.");
1136f09db6a3SJerry-Ge   } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
1137b6d67af2SFelix Schneider     return emitOpError("expect 'multiples' to have rank ")
1138f09db6a3SJerry-Ge            << outputType.getRank() << " but got " << multiplesRank << ".";
1139b6d67af2SFelix Schneider 
1140f09db6a3SJerry-Ge   SmallVector<int64_t> multiples;
1141f09db6a3SJerry-Ge   if (getConstantMultiples(multiples).succeeded() &&
1142f09db6a3SJerry-Ge       llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
1143c8568f09SLongsheng Mou     return emitOpError(
1144c8568f09SLongsheng Mou         "expect element of 'multiples' to be positive integer or -1.");
1145c8568f09SLongsheng Mou 
1146b6d67af2SFelix Schneider   return success();
1147b6d67af2SFelix Schneider }
1148b6d67af2SFelix Schneider 
11492dd396c1SAviad Cohen bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
11502dd396c1SAviad Cohen   if (l.size() != r.size() || l.size() != 1)
11512dd396c1SAviad Cohen     return false;
11522dd396c1SAviad Cohen   return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
11532dd396c1SAviad Cohen }
11542dd396c1SAviad Cohen 
11558dea784bSRob Suderman LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
115622426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1157057fc8e7SAmanda Tang     ReshapeOp::Adaptor adaptor,
11588dea784bSRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1159057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput1().getType());
1160057fc8e7SAmanda Tang   Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
11619e1a3441SAlexander Shaposhnikov   llvm::SmallVector<int64_t> newShapeValue =
11629e1a3441SAlexander Shaposhnikov       convertToMlirShape(adaptor.getNewShape());
11638dea784bSRob Suderman 
11648dea784bSRob Suderman   // We cannot infer from the total number of elements so we must take the
11658dea784bSRob Suderman   // shape attribute as exact.
116609349303SJacques Pienaar   if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
11672dd396c1SAviad Cohen     inferredReturnShapes.push_back(
11682dd396c1SAviad Cohen         ShapedTypeComponents(newShapeValue, inputType));
11698dea784bSRob Suderman     return success();
11708dea784bSRob Suderman   }
11718dea784bSRob Suderman 
11728dea784bSRob Suderman   // Determine the number of elements covered by the slice of all static
11738dea784bSRob Suderman   // dimensions. This allows us to infer the length of the remaining dynamic
11748dea784bSRob Suderman   // dimension.
117509349303SJacques Pienaar   int64_t numElements = inputShape.getNumElements();
11768dea784bSRob Suderman   int64_t staticMul = 1;
11778dea784bSRob Suderman   for (auto val : newShapeValue) {
1178fb4cedccSAliia Khasanova     if (!ShapedType::isDynamic(val)) {
11798dea784bSRob Suderman       staticMul *= val;
11808dea784bSRob Suderman     }
11818dea784bSRob Suderman   }
11828dea784bSRob Suderman 
11838dea784bSRob Suderman   // Determine the length of the dynamic dimension.
11848dea784bSRob Suderman   for (auto &val : newShapeValue) {
1185fb4cedccSAliia Khasanova     if (ShapedType::isDynamic(val))
11868dea784bSRob Suderman       val = numElements / staticMul;
11878dea784bSRob Suderman   }
11888dea784bSRob Suderman 
11892dd396c1SAviad Cohen   inferredReturnShapes.push_back(
11902dd396c1SAviad Cohen       ShapedTypeComponents(newShapeValue, inputType));
11918dea784bSRob Suderman   return success();
11928dea784bSRob Suderman }
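
// For example (illustrative values): reshaping a static input of shape
// [2, 6] (12 elements) with new_shape = [3, -1] computes staticMul = 3 and
// resolves the dynamic dimension to 12 / 3 = 4, i.e. an output of [3, 4].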
11938dea784bSRob Suderman 
1194db791b27SRamkumar Ramachandra llvm::LogicalResult tosa::ReshapeOp::verify() {
1195fbcd0c65SRafael Ubal   TensorType inputType = getInput1().getType();
1196fbcd0c65SRafael Ubal   RankedTensorType outputType = getType();
1197a315534eSa.puschin 
1198fbcd0c65SRafael Ubal   if (static_cast<int64_t>(getNewShape().size()) != outputType.getRank())
1199fbcd0c65SRafael Ubal     return emitOpError() << "new shape does not match result rank";
1200fbcd0c65SRafael Ubal 
1201fbcd0c65SRafael Ubal   for (auto [newShapeDim, outputShapeDim] :
1202019fbcc4SLongsheng Mou        zip(getNewShape(), outputType.getShape())) {
1203fbcd0c65SRafael Ubal     if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
1204fbcd0c65SRafael Ubal         newShapeDim != outputShapeDim)
1205fbcd0c65SRafael Ubal       return emitOpError() << "new shape is inconsistent with result shape";
1206fbcd0c65SRafael Ubal 
1207019fbcc4SLongsheng Mou     if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1208019fbcc4SLongsheng Mou       return emitOpError() << "new shape has invalid tensor dimension size "
1209019fbcc4SLongsheng Mou                            << newShapeDim;
1210019fbcc4SLongsheng Mou   }
1211019fbcc4SLongsheng Mou 
1212e6eb94d3SLongsheng Mou   if (inputType.hasStaticShape()) {
1213a315534eSa.puschin     int64_t inputElementsNum = inputType.getNumElements();
1214e6eb94d3SLongsheng Mou     if (outputType.hasStaticShape()) {
1215a315534eSa.puschin       int64_t outputElementsNum = outputType.getNumElements();
1216a315534eSa.puschin       if (inputElementsNum != outputElementsNum) {
1217fbcd0c65SRafael Ubal         return emitOpError() << "cannot reshape " << inputElementsNum
1218a315534eSa.puschin                              << " elements into " << outputElementsNum;
1219a315534eSa.puschin       }
1220a315534eSa.puschin     }
122126d896f3SRafael Ubal 
1222e6eb94d3SLongsheng Mou     int64_t newShapeElementsNum = std::accumulate(
1223e6eb94d3SLongsheng Mou         getNewShape().begin(), getNewShape().end(), 1LL,
1224e6eb94d3SLongsheng Mou         [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1225e6eb94d3SLongsheng Mou     bool isStaticNewShape =
1226e6eb94d3SLongsheng Mou         llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
1227e6eb94d3SLongsheng Mou     if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1228e6eb94d3SLongsheng Mou         (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1229e6eb94d3SLongsheng Mou       return emitOpError() << "cannot reshape " << inputElementsNum
1230e6eb94d3SLongsheng Mou                            << " elements into " << newShapeElementsNum;
1231e6eb94d3SLongsheng Mou     }
1232e6eb94d3SLongsheng Mou   }
1233e6eb94d3SLongsheng Mou 
123426d896f3SRafael Ubal   int missingDims = llvm::count(getNewShape(), -1);
123526d896f3SRafael Ubal   if (missingDims > 1)
1236fbcd0c65SRafael Ubal     return emitOpError() << "expected at most one target dimension to be -1";
123726d896f3SRafael Ubal 
1238a315534eSa.puschin   return mlir::success();
1239a315534eSa.puschin }
1240a315534eSa.puschin 
1241a54efdbdSArteen Abrishami LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
1242ded988edSAviad Cohen   // Perms must be constants.
1243ded988edSAviad Cohen   DenseIntElementsAttr permsAttr;
1244ded988edSAviad Cohen   if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
1245ded988edSAviad Cohen     return failure();
1246ded988edSAviad Cohen 
1247a54efdbdSArteen Abrishami   perms.clear();
1248a54efdbdSArteen Abrishami   for (auto v : permsAttr.getValues<APInt>())
1249a54efdbdSArteen Abrishami     perms.push_back(v.getSExtValue());
1250ded988edSAviad Cohen 
1251ded988edSAviad Cohen   return success();
1252ded988edSAviad Cohen }
1253ded988edSAviad Cohen 
12545a4e7760SRob Suderman LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
125522426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1256057fc8e7SAmanda Tang     TransposeOp::Adaptor adaptor,
12575a4e7760SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1258057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput1().getType());
1259057fc8e7SAmanda Tang   ShapeAdaptor permsShape(adaptor.getPerms().getType());
12605a4e7760SRob Suderman 
12617563eb64SFelix Schneider   // We cannot infer anything from a rank-0 "permutation" tensor.
12627563eb64SFelix Schneider   if (permsShape.hasRank() && permsShape.getRank() == 0)
12637563eb64SFelix Schneider     return failure();
12647563eb64SFelix Schneider 
12655a4e7760SRob Suderman   // If input rank and permutation length is unknown, the output rank is
12665a4e7760SRob Suderman   // unknown.
1267826d3eaaSRob Suderman   if (!inputShape.hasRank() || !permsShape.hasRank() ||
1268826d3eaaSRob Suderman       permsShape.isDynamicDim(0)) {
12695a4e7760SRob Suderman     inferredReturnShapes.push_back(ShapedTypeComponents());
12705a4e7760SRob Suderman     return success();
12715a4e7760SRob Suderman   }
12725a4e7760SRob Suderman 
1273a54efdbdSArteen Abrishami   // This would imply the number of permutations does not match the rank of
1274a54efdbdSArteen Abrishami   // the input, which is illegal.
1275826d3eaaSRob Suderman   if (permsShape.getDimSize(0) != inputShape.getRank()) {
1276826d3eaaSRob Suderman     return failure();
1277826d3eaaSRob Suderman   }
1278826d3eaaSRob Suderman 
12795a4e7760SRob Suderman   SmallVector<int64_t> outputShape;
12805a4e7760SRob Suderman   // Rank-0 means no permutations matter.
128109349303SJacques Pienaar   if (inputShape.getRank() == 0) {
12825a4e7760SRob Suderman     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
12835a4e7760SRob Suderman     return success();
12845a4e7760SRob Suderman   }
12855a4e7760SRob Suderman 
12865a4e7760SRob Suderman   // Check whether the input dimensions are all the same.
12875a4e7760SRob Suderman   bool allTheSame = true;
128809349303SJacques Pienaar   for (int i = 1, s = inputShape.getRank(); i < s; i++) {
128909349303SJacques Pienaar     if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
12905a4e7760SRob Suderman       allTheSame = false;
12915a4e7760SRob Suderman       break;
12925a4e7760SRob Suderman     }
12935a4e7760SRob Suderman   }
12945a4e7760SRob Suderman 
12955a4e7760SRob Suderman   // If all of the input dimensions are the same we don't care about the
12965a4e7760SRob Suderman   // permutation.
12975a4e7760SRob Suderman   if (allTheSame) {
129809349303SJacques Pienaar     outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
12995a4e7760SRob Suderman     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
13005a4e7760SRob Suderman     return success();
13015a4e7760SRob Suderman   }
13025a4e7760SRob Suderman 
1303399638f9SAliia Khasanova   outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
13045a4e7760SRob Suderman   // If the permutations are constant we can directly determine the output
13055a4e7760SRob Suderman   // shape.
1306057fc8e7SAmanda Tang   DenseIntElementsAttr attr;
1307057fc8e7SAmanda Tang   if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1308057fc8e7SAmanda Tang       attr.getType().getRank() == 1) {
1309057fc8e7SAmanda Tang     ShapeAdaptor permShape = attr;
131047793481SKai Sasaki     // Constant permutation must be the same length as the input rank.
131147793481SKai Sasaki     if (inputShape.getRank() != permShape.getRank())
131247793481SKai Sasaki       return emitOptionalError(location,
131347793481SKai Sasaki                                "constant permutation length must match the"
131447793481SKai Sasaki                                " input rank");
131547793481SKai Sasaki 
131647793481SKai Sasaki     // Constant permutation values must be within the input rank.
131747793481SKai Sasaki     for (int i = 0, e = inputShape.getRank(); i < e; i++) {
131847793481SKai Sasaki       if (inputShape.getRank() <= permShape.getDimSize(i))
131947793481SKai Sasaki         return failure();
132047793481SKai Sasaki     }
132147793481SKai Sasaki 
132309349303SJacques Pienaar     for (int i = 0, s = inputShape.getRank(); i < s; i++) {
132409349303SJacques Pienaar       outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
13255a4e7760SRob Suderman     }
13265a4e7760SRob Suderman   }
13275a4e7760SRob Suderman 
13285a4e7760SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
13295a4e7760SRob Suderman   return success();
13305a4e7760SRob Suderman }
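
// For example (illustrative values): an input of shape [1, 2, 3] with
// constant perms = [2, 0, 1] infers output dim i as input dim perms[i],
// i.e. an output of [3, 1, 2].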
13315a4e7760SRob Suderman 
13328190369eSFelix Schneider LogicalResult tosa::TransposeOp::verify() {
13338190369eSFelix Schneider   TensorType inputType = getInput1().getType();
13348190369eSFelix Schneider   TensorType permType = getPerms().getType();
13358190369eSFelix Schneider   TensorType outputType = getOutput().getType();
13368190369eSFelix Schneider 
13378190369eSFelix Schneider   if (permType.hasRank() && permType.getRank() != 1)
13388190369eSFelix Schneider     return emitOpError()
13398190369eSFelix Schneider            << "expected permutation tensor to be rank 1 but got rank "
13408190369eSFelix Schneider            << permType.getRank();
13418190369eSFelix Schneider   if (inputType.hasRank() && permType.hasRank())
13428190369eSFelix Schneider     if (!permType.isDynamicDim(0) &&
13438190369eSFelix Schneider         permType.getDimSize(0) != inputType.getRank())
13448190369eSFelix Schneider       return emitOpError() << "expected permutation tensor dim 0 to have size "
13458190369eSFelix Schneider                            << inputType.getRank()
13468190369eSFelix Schneider                            << " (input rank) but got size "
13478190369eSFelix Schneider                            << permType.getDimSize(0);
13488190369eSFelix Schneider   if (inputType.hasRank() && outputType.hasRank() &&
13498190369eSFelix Schneider       inputType.getRank() != outputType.getRank())
13508190369eSFelix Schneider     return emitOpError()
13518190369eSFelix Schneider            << "expected input tensor rank to equal result tensor rank";
13528190369eSFelix Schneider   if (outputType.hasRank() && permType.hasRank())
13538190369eSFelix Schneider     if (!permType.isDynamicDim(0) &&
13548190369eSFelix Schneider         permType.getDimSize(0) != outputType.getRank())
13558190369eSFelix Schneider       return emitOpError() << "expected permutation tensor dim 0 to have size "
13568190369eSFelix Schneider                            << outputType.getRank()
13578190369eSFelix Schneider                            << " (output rank) but got size "
13588190369eSFelix Schneider                            << permType.getDimSize(0);
13598190369eSFelix Schneider 
1360a54efdbdSArteen Abrishami   SmallVector<int32_t> constantPerms;
13618190369eSFelix Schneider   if (succeeded(getConstantPerms(constantPerms))) {
1362a54efdbdSArteen Abrishami     // Assert that the permutation tensor has a rank, which means that the
1363a54efdbdSArteen Abrishami     // rank has been verified above.
13648190369eSFelix Schneider     assert(permType.hasRank() &&
13658190369eSFelix Schneider            "Unexpectedly found permutation tensor without rank");
1366a54efdbdSArteen Abrishami     if (!llvm::all_of(constantPerms,
1367a54efdbdSArteen Abrishami                       [&constantPerms](int32_t s) {
1368a54efdbdSArteen Abrishami                         return s >= 0 &&
1369a54efdbdSArteen Abrishami                                static_cast<size_t>(s) < constantPerms.size();
1370a54efdbdSArteen Abrishami                       }) ||
1371a54efdbdSArteen Abrishami         !isPermutationVector(llvm::to_vector(llvm::map_range(
1372a54efdbdSArteen Abrishami             constantPerms, [](int32_t v) -> int64_t { return v; }))))
13738190369eSFelix Schneider       return emitOpError() << "expected valid permutation tensor";
1374c8b5d30fSDarshanRamakant 
1375a54efdbdSArteen Abrishami     // Verify that the types of the input and output tensors are properly
1376a54efdbdSArteen Abrishami     // permuted.
1377a54efdbdSArteen Abrishami     if (inputType.hasRank() && outputType.hasRank()) {
1378a54efdbdSArteen Abrishami       assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1379a54efdbdSArteen Abrishami              inputType.getRank() == outputType.getRank());
1380a54efdbdSArteen Abrishami 
1381a54efdbdSArteen Abrishami       for (auto i = 0; i < outputType.getRank(); i++) {
1382a54efdbdSArteen Abrishami         if (inputType.isDynamicDim(constantPerms[i]) ||
1383a54efdbdSArteen Abrishami             outputType.isDynamicDim(i))
1384a54efdbdSArteen Abrishami           continue;
1385a54efdbdSArteen Abrishami 
1386a54efdbdSArteen Abrishami         if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
1387a54efdbdSArteen Abrishami           return emitOpError()
1388a54efdbdSArteen Abrishami                  << "expected output tensor dim " << i << " to match "
1389a54efdbdSArteen Abrishami                  << "input dim " << constantPerms[i] << " with value of "
1390a54efdbdSArteen Abrishami                  << inputType.getDimSize(constantPerms[i]);
1391a54efdbdSArteen Abrishami       }
1392c8b5d30fSDarshanRamakant     }
13938190369eSFelix Schneider   }
13948190369eSFelix Schneider   return success();
13958190369eSFelix Schneider }
13968190369eSFelix Schneider 
13975cd074faSMaya Amrami LogicalResult TransposeOp::reifyResultShapes(
13985cd074faSMaya Amrami     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
13995cd074faSMaya Amrami 
1400a54efdbdSArteen Abrishami   SmallVector<int32_t> transposePerms;
14015cd074faSMaya Amrami   if (getConstantPerms(transposePerms).failed())
14025cd074faSMaya Amrami     return failure();
14035cd074faSMaya Amrami 
14045cd074faSMaya Amrami   Value input = getInput1();
1405d2353695SPeiming Liu   auto inputType = cast<TensorType>(input.getType());
14065cd074faSMaya Amrami 
14075cd074faSMaya Amrami   SmallVector<OpFoldResult> returnedDims(inputType.getRank());
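  // Iterating over the permutation's values (rather than its indices) still
  // visits every output dimension exactly once, since perms is a permutation
  // of [0, rank).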
14085cd074faSMaya Amrami   for (auto dim : transposePerms) {
1409a54efdbdSArteen Abrishami     int32_t dimInInput = transposePerms[dim];
14105cd074faSMaya Amrami     if (inputType.isDynamicDim(dimInInput))
14115cd074faSMaya Amrami       returnedDims[dim] =
14125cd074faSMaya Amrami           builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
14135cd074faSMaya Amrami               .getResult();
14145cd074faSMaya Amrami     else
14155cd074faSMaya Amrami       returnedDims[dim] =
14165cd074faSMaya Amrami           builder.getIndexAttr(inputType.getDimSize(dimInInput));
14175cd074faSMaya Amrami   }
14185cd074faSMaya Amrami 
14195cd074faSMaya Amrami   reifiedReturnShapes.emplace_back(std::move(returnedDims));
14205cd074faSMaya Amrami   return success();
14215cd074faSMaya Amrami }
14225cd074faSMaya Amrami 
14235a4e7760SRob Suderman LogicalResult tosa::GatherOp::inferReturnTypeComponents(
142422426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1425057fc8e7SAmanda Tang     GatherOp::Adaptor adaptor,
14265a4e7760SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
14275a4e7760SRob Suderman   llvm::SmallVector<int64_t> outputShape;
1428399638f9SAliia Khasanova   outputShape.resize(3, ShapedType::kDynamic);
14295a4e7760SRob Suderman 
1430057fc8e7SAmanda Tang   ShapeAdaptor valuesShape(adaptor.getValues().getType());
143109349303SJacques Pienaar   if (valuesShape.hasRank()) {
143209349303SJacques Pienaar     outputShape[0] = valuesShape.getDimSize(0);
143309349303SJacques Pienaar     outputShape[2] = valuesShape.getDimSize(2);
14345a4e7760SRob Suderman   }
14355a4e7760SRob Suderman 
1436057fc8e7SAmanda Tang   ShapeAdaptor indicesShape(adaptor.getIndices().getType());
143709349303SJacques Pienaar   if (indicesShape.hasRank()) {
1438399638f9SAliia Khasanova     if (outputShape[0] == ShapedType::kDynamic)
143909349303SJacques Pienaar       outputShape[0] = indicesShape.getDimSize(0);
1440399638f9SAliia Khasanova     if (outputShape[1] == ShapedType::kDynamic)
144109349303SJacques Pienaar       outputShape[1] = indicesShape.getDimSize(1);
14425a4e7760SRob Suderman   }
14435a4e7760SRob Suderman 
14445a4e7760SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
14455a4e7760SRob Suderman   return success();
14465a4e7760SRob Suderman }
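
// For example (illustrative values): values of shape [N, K, C] = [2, 5, 3]
// and indices of shape [N, W] = [2, 4] infer an output of shape
// [N, W, C] = [2, 4, 3].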
14475a4e7760SRob Suderman 
1448143edecaSRob Suderman LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
144922426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1450057fc8e7SAmanda Tang     ResizeOp::Adaptor adaptor,
1451143edecaSRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1452143edecaSRob Suderman   llvm::SmallVector<int64_t, 4> outputShape;
1453399638f9SAliia Khasanova   outputShape.resize(4, ShapedType::kDynamic);
1454143edecaSRob Suderman 
1455057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput().getType());
1456ff23599aSTatWai Chong   if (!inputShape.hasRank())
1457ff23599aSTatWai Chong     return failure();
1458ff23599aSTatWai Chong 
145909349303SJacques Pienaar   outputShape[0] = inputShape.getDimSize(0);
146009349303SJacques Pienaar   outputShape[3] = inputShape.getDimSize(3);
1461fb4cedccSAliia Khasanova   int64_t inputHeight = inputShape.getDimSize(1);
1462fb4cedccSAliia Khasanova   int64_t inputWidth = inputShape.getDimSize(2);
1463143edecaSRob Suderman 
1464399638f9SAliia Khasanova   if ((inputHeight == ShapedType::kDynamic) ||
1465399638f9SAliia Khasanova       (inputWidth == ShapedType::kDynamic))
1466ff23599aSTatWai Chong     return failure();
1467143edecaSRob Suderman 
146811030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
146911030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
147011030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1471143edecaSRob Suderman 
1472ff23599aSTatWai Chong   // Compute the output shape based on attributes: scale, offset, and border.
1473ff23599aSTatWai Chong   outputShape[1] =
1474ff23599aSTatWai Chong       (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1475ff23599aSTatWai Chong        scaleInt[1]) +
1476ff23599aSTatWai Chong       1;
1477143edecaSRob Suderman 
1478ff23599aSTatWai Chong   outputShape[2] =
1479ff23599aSTatWai Chong       (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1480ff23599aSTatWai Chong        scaleInt[3]) +
1481ff23599aSTatWai Chong       1;
1482143edecaSRob Suderman 
1483143edecaSRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1484143edecaSRob Suderman   return success();
1485143edecaSRob Suderman }
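
// For example (illustrative values): an input of shape [1, 16, 16, 8] with
// scale = [2, 1, 2, 1] and zero offset and border infers
// ((16 - 1) * 2 - 0 + 0) / 1 + 1 = 31 for both spatial dims, i.e. an output
// of [1, 31, 31, 8].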
1486143edecaSRob Suderman 
14875a4e7760SRob Suderman LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
148822426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1489057fc8e7SAmanda Tang     ScatterOp::Adaptor adaptor,
14905a4e7760SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
14915a4e7760SRob Suderman   llvm::SmallVector<int64_t> outputShape;
1492399638f9SAliia Khasanova   outputShape.resize(3, ShapedType::kDynamic);
14935a4e7760SRob Suderman 
1494057fc8e7SAmanda Tang   ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
149509349303SJacques Pienaar   if (valuesInShape.hasRank()) {
149609349303SJacques Pienaar     outputShape[0] = valuesInShape.getDimSize(0);
149709349303SJacques Pienaar     outputShape[1] = valuesInShape.getDimSize(1);
149809349303SJacques Pienaar     outputShape[2] = valuesInShape.getDimSize(2);
14995a4e7760SRob Suderman   }
15005a4e7760SRob Suderman 
1501057fc8e7SAmanda Tang   ShapeAdaptor indicesShape(adaptor.getIndices().getType());
150209349303SJacques Pienaar   if (indicesShape.hasRank()) {
1503399638f9SAliia Khasanova     if (outputShape[0] == ShapedType::kDynamic)
150409349303SJacques Pienaar       outputShape[0] = indicesShape.getDimSize(0);
15055a4e7760SRob Suderman   }
15065a4e7760SRob Suderman 
1507057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput().getType());
150809349303SJacques Pienaar   if (inputShape.hasRank()) {
1509399638f9SAliia Khasanova     if (outputShape[0] == ShapedType::kDynamic)
151009349303SJacques Pienaar       outputShape[0] = inputShape.getDimSize(0);
1511399638f9SAliia Khasanova     if (outputShape[2] == ShapedType::kDynamic)
151209349303SJacques Pienaar       outputShape[2] = inputShape.getDimSize(2);
15135a4e7760SRob Suderman   }
15145a4e7760SRob Suderman 
15155a4e7760SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
15165a4e7760SRob Suderman   return success();
15175a4e7760SRob Suderman }
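
// For example (illustrative values): values_in of shape [N, K, C] = [2, 5, 3],
// indices of shape [N, W] = [2, 4], and input of shape [N, W, C] = [2, 4, 3]
// infer an output of shape [2, 5, 3].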
15185a4e7760SRob Suderman 
15195a4e7760SRob Suderman static LogicalResult ReduceInferReturnTypes(
15203500e110SAviad Cohen     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
15215a4e7760SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
15228a57bc09SFelix Schneider   int64_t axisVal = axis.getValue().getSExtValue();
15238a57bc09SFelix Schneider   if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
15243500e110SAviad Cohen     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
15255a4e7760SRob Suderman     return success();
15265a4e7760SRob Suderman   }
15275a4e7760SRob Suderman 
15285a4e7760SRob Suderman   SmallVector<int64_t> outputShape;
152909349303SJacques Pienaar   operandShape.getDims(outputShape);
15305a4e7760SRob Suderman   outputShape[axisVal] = 1;
15313500e110SAviad Cohen   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
15325a4e7760SRob Suderman   return success();
15335a4e7760SRob Suderman }
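
// For example (illustrative values): reducing an input of shape [2, 3, 4]
// along axis = 1 infers an output of shape [2, 1, 4].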
15345a4e7760SRob Suderman 
15353500e110SAviad Cohen #define COMPATIBLE_RETURN_TYPES(OP)                                            \
15363500e110SAviad Cohen   bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
15373500e110SAviad Cohen     if (l.size() != r.size() || l.size() != 1)                                 \
15383500e110SAviad Cohen       return false;                                                            \
15393500e110SAviad Cohen     if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
15403500e110SAviad Cohen       return false;                                                            \
15413500e110SAviad Cohen     return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
15423500e110SAviad Cohen   }
15433500e110SAviad Cohen 
15445a4e7760SRob Suderman #define REDUCE_SHAPE_INFER(OP)                                                 \
15455a4e7760SRob Suderman   LogicalResult OP::inferReturnTypeComponents(                                 \
154622426110SRamkumar Ramachandra       MLIRContext *context, ::std::optional<Location> location,                \
1547057fc8e7SAmanda Tang       OP::Adaptor adaptor,                                                     \
15485a4e7760SRob Suderman       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
15493500e110SAviad Cohen     Type inputType =                                                           \
1550057fc8e7SAmanda Tang         llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1551057fc8e7SAmanda Tang     ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
1552057fc8e7SAmanda Tang     const Properties &prop = adaptor.getProperties();                          \
1553057fc8e7SAmanda Tang     return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
15545a4e7760SRob Suderman                                   inferredReturnShapes);                       \
15553500e110SAviad Cohen   }                                                                            \
15563500e110SAviad Cohen   COMPATIBLE_RETURN_TYPES(OP)
15575a4e7760SRob Suderman 
15585a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
15595a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
15605a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
15615a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
15625a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
15635a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
15645a4e7760SRob Suderman #undef REDUCE_SHAPE_INFER
15653500e110SAviad Cohen COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
15663500e110SAviad Cohen #undef COMPATIBLE_RETURN_TYPES
15675a4e7760SRob Suderman 
15688a57bc09SFelix Schneider template <typename T>
15698a57bc09SFelix Schneider static LogicalResult verifyReduceOp(T op) {
15708a57bc09SFelix Schneider   // All TOSA reduce Ops have input, output and axis.
15718a57bc09SFelix Schneider   TensorType inputType = op.getInput().getType();
15728a57bc09SFelix Schneider   TensorType outputType = op.getOutput().getType();
15738a57bc09SFelix Schneider   int32_t reduceAxis = op.getAxis();
15748a57bc09SFelix Schneider 
15758a57bc09SFelix Schneider   if (reduceAxis < 0) {
15768a57bc09SFelix Schneider     op.emitOpError("reduce axis must not be negative");
15778a57bc09SFelix Schneider     return failure();
15788a57bc09SFelix Schneider   }
15798a57bc09SFelix Schneider   if (inputType.hasRank()) {
15808a57bc09SFelix Schneider     int64_t inputRank = inputType.getRank();
15818a57bc09SFelix Schneider     // We allow for a special case where the input/output shape has rank 0 and
15828a57bc09SFelix Schneider     // axis is also 0.
15838a57bc09SFelix Schneider     if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
15848a57bc09SFelix Schneider       op.emitOpError("expect input tensor rank (")
15858a57bc09SFelix Schneider           << inputRank << ") to be larger than reduce axis (" << reduceAxis
15868a57bc09SFelix Schneider           << ")";
15878a57bc09SFelix Schneider       return failure();
15888a57bc09SFelix Schneider     }
15898a57bc09SFelix Schneider   }
15908a57bc09SFelix Schneider   if (outputType.hasRank()) {
15918a57bc09SFelix Schneider     int64_t outputRank = outputType.getRank();
15928a57bc09SFelix Schneider     if (inputType.hasRank() && outputRank != inputType.getRank()) {
15938a57bc09SFelix Schneider       op.emitOpError(
15948a57bc09SFelix Schneider           "expect output tensor rank to be equal to input tensor rank");
15958a57bc09SFelix Schneider       return failure();
15968a57bc09SFelix Schneider     }
15978a57bc09SFelix Schneider     if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
15988a57bc09SFelix Schneider       op.emitOpError("expect output tensor rank (")
15998a57bc09SFelix Schneider           << outputRank << ") to be larger than reduce axis (" << reduceAxis
16008a57bc09SFelix Schneider           << ")";
16018a57bc09SFelix Schneider       return failure();
16028a57bc09SFelix Schneider     }
1603a54efdbdSArteen Abrishami     // We can only verify the reduced dimension size to be 1 if this is not
1604a54efdbdSArteen Abrishami     // the special case of output rank == 0.
16058a57bc09SFelix Schneider     if (outputRank != 0) {
16068a57bc09SFelix Schneider       auto outputShape = outputType.getShape();
16078a57bc09SFelix Schneider       if (!outputType.isDynamicDim(reduceAxis) &&
16088a57bc09SFelix Schneider           outputShape[reduceAxis] != 1) {
16098a57bc09SFelix Schneider         op.emitOpError("expect reduced dimension size to be 1, got ")
16108a57bc09SFelix Schneider             << outputShape[reduceAxis];
16118a57bc09SFelix Schneider         return failure();
16128a57bc09SFelix Schneider       }
16138a57bc09SFelix Schneider     }
16148a57bc09SFelix Schneider   }
16158a57bc09SFelix Schneider   return success();
16168a57bc09SFelix Schneider }
16178a57bc09SFelix Schneider 
16188a57bc09SFelix Schneider LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
16198a57bc09SFelix Schneider LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
16208a57bc09SFelix Schneider LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
16218a57bc09SFelix Schneider LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
16228a57bc09SFelix Schneider LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
16238a57bc09SFelix Schneider LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
16248a57bc09SFelix Schneider 
16258dea784bSRob Suderman static LogicalResult NAryInferReturnTypes(
162609349303SJacques Pienaar     const ValueShapeRange &operands,
16278dea784bSRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
16288dea784bSRob Suderman   llvm::SmallVector<int64_t> outShape;
16298dea784bSRob Suderman   if (resolveBroadcastShape(operands, outShape).failed()) {
16308dea784bSRob Suderman     inferredReturnShapes.push_back(ShapedTypeComponents());
16318dea784bSRob Suderman   } else {
16328dea784bSRob Suderman     inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
16338dea784bSRob Suderman   }
16348dea784bSRob Suderman   return success();
16358dea784bSRob Suderman }
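
// For example (illustrative values, standard TOSA broadcast semantics):
// operand shapes [2, 1, 3] and [1, 4, 3] resolve to an output shape of
// [2, 4, 3].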
16368dea784bSRob Suderman 
16378dea784bSRob Suderman #define NARY_SHAPE_INFER(OP)                                                   \
16388dea784bSRob Suderman   LogicalResult OP::inferReturnTypeComponents(                                 \
163922426110SRamkumar Ramachandra       MLIRContext *context, ::std::optional<Location> location,                \
1640d425f589SJacques Pienaar       ValueShapeRange operands, DictionaryAttr attributes,                     \
16415e118f93SMehdi Amini       OpaqueProperties properties, RegionRange regions,                        \
16428dea784bSRob Suderman       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
16438dea784bSRob Suderman     return NAryInferReturnTypes(operands, inferredReturnShapes);               \
16448dea784bSRob Suderman   }
16458dea784bSRob Suderman 
16468dea784bSRob Suderman NARY_SHAPE_INFER(tosa::AbsOp)
16478dea784bSRob Suderman NARY_SHAPE_INFER(tosa::AddOp)
16488dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
16498dea784bSRob Suderman NARY_SHAPE_INFER(tosa::BitwiseAndOp)
16508dea784bSRob Suderman NARY_SHAPE_INFER(tosa::BitwiseOrOp)
16518dea784bSRob Suderman NARY_SHAPE_INFER(tosa::BitwiseXorOp)
16528dea784bSRob Suderman NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1653143edecaSRob Suderman NARY_SHAPE_INFER(tosa::CastOp)
16548dea784bSRob Suderman NARY_SHAPE_INFER(tosa::CeilOp)
16558dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ClampOp)
16568dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ClzOp)
1657d57f158aSJerry-Ge NARY_SHAPE_INFER(tosa::CosOp)
16588dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ExpOp)
16598dea784bSRob Suderman NARY_SHAPE_INFER(tosa::FloorOp)
16608dea784bSRob Suderman NARY_SHAPE_INFER(tosa::GreaterEqualOp)
16618dea784bSRob Suderman NARY_SHAPE_INFER(tosa::GreaterOp)
1662143edecaSRob Suderman NARY_SHAPE_INFER(tosa::IdentityOp)
166382383d5fSTai Ly NARY_SHAPE_INFER(tosa::IntDivOp)
16648dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogOp)
16658dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalAndOp)
16668dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
16678dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalNotOp)
16688dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalOrOp)
16698dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
16708dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalXorOp)
16718dea784bSRob Suderman NARY_SHAPE_INFER(tosa::MaximumOp)
16728dea784bSRob Suderman NARY_SHAPE_INFER(tosa::MinimumOp)
16738dea784bSRob Suderman NARY_SHAPE_INFER(tosa::MulOp)
16748dea784bSRob Suderman NARY_SHAPE_INFER(tosa::NegateOp)
16758dea784bSRob Suderman NARY_SHAPE_INFER(tosa::PowOp)
16768dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ReciprocalOp)
1677143edecaSRob Suderman NARY_SHAPE_INFER(tosa::RescaleOp)
16788dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ReverseOp)
16798dea784bSRob Suderman NARY_SHAPE_INFER(tosa::RsqrtOp)
1680d57f158aSJerry-Ge NARY_SHAPE_INFER(tosa::SinOp)
16818dea784bSRob Suderman NARY_SHAPE_INFER(tosa::SelectOp)
16828dea784bSRob Suderman NARY_SHAPE_INFER(tosa::SubOp)
16838dea784bSRob Suderman NARY_SHAPE_INFER(tosa::TanhOp)
16841fef1f97SManupa Karunaratne NARY_SHAPE_INFER(tosa::ErfOp)
16858dea784bSRob Suderman NARY_SHAPE_INFER(tosa::SigmoidOp)
16868dea784bSRob Suderman #undef NARY_SHAPE_INFER
16878dea784bSRob Suderman 
1688f2832c22SRob Suderman static LogicalResult poolingInferReturnTypes(
1689057fc8e7SAmanda Tang     ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1690057fc8e7SAmanda Tang     ArrayRef<int64_t> pad,
1691f2832c22SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1692f2832c22SRob Suderman   llvm::SmallVector<int64_t> outputShape;
1693399638f9SAliia Khasanova   outputShape.resize(4, ShapedType::kDynamic);
1694f2832c22SRob Suderman 
1695f2832c22SRob Suderman   // If the input shape is unavailable, we only know the output rank; all
1695f2832c22SRob Suderman   // four dimensions are dynamic.
169609349303SJacques Pienaar   if (!inputShape) {
1697f2832c22SRob Suderman     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1698f2832c22SRob Suderman     return success();
1699f2832c22SRob Suderman   }
1700f2832c22SRob Suderman 
1701f2832c22SRob Suderman   // Batch and number of channels are identical for pooling layer.
170209349303SJacques Pienaar   outputShape[0] = inputShape.getDimSize(0);
170309349303SJacques Pienaar   outputShape[3] = inputShape.getDimSize(3);
1704f2832c22SRob Suderman 
1705fb4cedccSAliia Khasanova   int64_t height = inputShape.getDimSize(1);
1706fb4cedccSAliia Khasanova   int64_t width = inputShape.getDimSize(2);
1707f2832c22SRob Suderman 
1708fb4cedccSAliia Khasanova   if (!ShapedType::isDynamic(height)) {
1709fb4cedccSAliia Khasanova     int64_t padded = height + pad[0] + pad[1] - kernel[0];
1710f2832c22SRob Suderman     outputShape[1] = padded / stride[0] + 1;
1711f2832c22SRob Suderman   }
1712f2832c22SRob Suderman 
1713fb4cedccSAliia Khasanova   if (!ShapedType::isDynamic(width)) {
1714fb4cedccSAliia Khasanova     int64_t padded = width + pad[2] + pad[3] - kernel[1];
1715f2832c22SRob Suderman     outputShape[2] = padded / stride[1] + 1;
1716f2832c22SRob Suderman   }
1717f2832c22SRob Suderman 
1718f2832c22SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1719f2832c22SRob Suderman   return success();
1720f2832c22SRob Suderman }
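
// For example (illustrative values): an input of shape [1, 32, 32, 8] with
// kernel = [3, 3], stride = [2, 2], and pad = [1, 1, 1, 1] infers
// (32 + 1 + 1 - 3) / 2 + 1 = 16 for both spatial dims, i.e. an output of
// [1, 16, 16, 8].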
1721f2832c22SRob Suderman 
172211dda1a2SRob Suderman LogicalResult Conv2DOp::inferReturnTypeComponents(
172322426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1724057fc8e7SAmanda Tang     Conv2DOp::Adaptor adaptor,
172511dda1a2SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1726399638f9SAliia Khasanova   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
172711dda1a2SRob Suderman 
1728399638f9SAliia Khasanova   int64_t inputWidth = ShapedType::kDynamic;
1729399638f9SAliia Khasanova   int64_t inputHeight = ShapedType::kDynamic;
1730399638f9SAliia Khasanova   int64_t weightWidth = ShapedType::kDynamic;
1731399638f9SAliia Khasanova   int64_t weightHeight = ShapedType::kDynamic;
173211dda1a2SRob Suderman 
173311dda1a2SRob Suderman   // Input shape describes input width/height and batch.
1735057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput().getType());
173609349303SJacques Pienaar   if (inputShape.hasRank()) {
173709349303SJacques Pienaar     outputShape[0] = inputShape.getDimSize(0);
173809349303SJacques Pienaar     inputHeight = inputShape.getDimSize(1);
173909349303SJacques Pienaar     inputWidth = inputShape.getDimSize(2);
174011dda1a2SRob Suderman   }
174111dda1a2SRob Suderman 
174211dda1a2SRob Suderman   // Weight shape describes the filter width/height and the output channels.
1743057fc8e7SAmanda Tang   ShapeAdaptor weightShape(adaptor.getWeight().getType());
174409349303SJacques Pienaar   if (weightShape.hasRank()) {
174509349303SJacques Pienaar     outputShape[3] = weightShape.getDimSize(0);
174609349303SJacques Pienaar     weightHeight = weightShape.getDimSize(1);
174709349303SJacques Pienaar     weightWidth = weightShape.getDimSize(2);
174811dda1a2SRob Suderman   }
174911dda1a2SRob Suderman 
175011dda1a2SRob Suderman   // Bias shape can describe the output channels.
1751057fc8e7SAmanda Tang   ShapeAdaptor biasShape(adaptor.getBias().getType());
175209349303SJacques Pienaar   if (biasShape.hasRank()) {
175311dda1a2SRob Suderman     outputShape[3] = ShapedType::isDynamic(outputShape[3])
175409349303SJacques Pienaar                          ? biasShape.getDimSize(0)
175511dda1a2SRob Suderman                          : outputShape[3];
175611dda1a2SRob Suderman   }
175711dda1a2SRob Suderman 
175811030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
175911030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
176011030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> padding = adaptor.getPad();
176111dda1a2SRob Suderman 
176211dda1a2SRob Suderman   if (!ShapedType::isDynamic(inputHeight) &&
176311dda1a2SRob Suderman       !ShapedType::isDynamic(weightHeight)) {
1764fb4cedccSAliia Khasanova     int64_t inputSize = inputHeight + padding[0] + padding[1];
1765fb4cedccSAliia Khasanova     int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1766fb4cedccSAliia Khasanova     int64_t unstridedResult = inputSize - filterSize + 1;
176711dda1a2SRob Suderman     outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
176811dda1a2SRob Suderman   }
176911dda1a2SRob Suderman 
177011dda1a2SRob Suderman   if (!ShapedType::isDynamic(inputWidth) &&
177111dda1a2SRob Suderman       !ShapedType::isDynamic(weightWidth)) {
1772fb4cedccSAliia Khasanova     int64_t inputSize = inputWidth + padding[2] + padding[3];
1773fb4cedccSAliia Khasanova     int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1774fb4cedccSAliia Khasanova     int64_t unstridedResult = inputSize - filterSize + 1;
177511dda1a2SRob Suderman     outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
177611dda1a2SRob Suderman   }
177711dda1a2SRob Suderman 
177811dda1a2SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
177911dda1a2SRob Suderman   return success();
178011dda1a2SRob Suderman }
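
// For example (illustrative values): an input of shape [1, 32, 32, 4] and a
// weight of shape [8, 3, 3, 4] with unit stride and dilation and zero pad
// give filterSize = 3 and unstridedResult = 30, i.e. an output of
// [1, 30, 30, 8].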
178111dda1a2SRob Suderman 
1782360a03c9SJack Frankland LogicalResult Conv2DOp::verify() {
1783360a03c9SJack Frankland   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1784360a03c9SJack Frankland     return failure();
1785360a03c9SJack Frankland   return success();
1786360a03c9SJack Frankland }
17871be88f5aSRiver Riddle 
178811dda1a2SRob Suderman LogicalResult Conv3DOp::inferReturnTypeComponents(
178922426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1790057fc8e7SAmanda Tang     Conv3DOp::Adaptor adaptor,
179111dda1a2SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1792399638f9SAliia Khasanova   llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
179311dda1a2SRob Suderman 
1794399638f9SAliia Khasanova   int64_t inputWidth = ShapedType::kDynamic;
1795399638f9SAliia Khasanova   int64_t inputHeight = ShapedType::kDynamic;
1796399638f9SAliia Khasanova   int64_t inputDepth = ShapedType::kDynamic;
179711dda1a2SRob Suderman 
1798399638f9SAliia Khasanova   int64_t weightWidth = ShapedType::kDynamic;
1799399638f9SAliia Khasanova   int64_t weightHeight = ShapedType::kDynamic;
1800399638f9SAliia Khasanova   int64_t weightDepth = ShapedType::kDynamic;
180111dda1a2SRob Suderman 
180211dda1a2SRob Suderman   // Input shape describes input depth/height/width and batch.
1803057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput().getType());
180409349303SJacques Pienaar   if (inputShape.hasRank()) {
180509349303SJacques Pienaar     outputShape[0] = inputShape.getDimSize(0);
1806eb04f321STatWai Chong     inputDepth = inputShape.getDimSize(1);
1807eb04f321STatWai Chong     inputHeight = inputShape.getDimSize(2);
1808eb04f321STatWai Chong     inputWidth = inputShape.getDimSize(3);
180911dda1a2SRob Suderman   }
181011dda1a2SRob Suderman 
181111dda1a2SRob Suderman   // Weight shape describes the filter depth/height/width and output channels.
1812057fc8e7SAmanda Tang   ShapeAdaptor weightShape(adaptor.getWeight().getType());
181309349303SJacques Pienaar   if (weightShape.hasRank()) {
181409349303SJacques Pienaar     outputShape[4] = weightShape.getDimSize(0);
1815eb04f321STatWai Chong     weightDepth = weightShape.getDimSize(1);
1816eb04f321STatWai Chong     weightHeight = weightShape.getDimSize(2);
1817eb04f321STatWai Chong     weightWidth = weightShape.getDimSize(3);
181811dda1a2SRob Suderman   }
181911dda1a2SRob Suderman 
182011dda1a2SRob Suderman   // Bias shape can describe the output channels.
1821057fc8e7SAmanda Tang   ShapeAdaptor biasShape(adaptor.getBias().getType());
1822eb04f321STatWai Chong   if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1823eb04f321STatWai Chong     outputShape[4] = biasShape.getDimSize(0);
182411dda1a2SRob Suderman   }
182511dda1a2SRob Suderman 
182611030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
182711030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
182811030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> pad = adaptor.getPad();
182911dda1a2SRob Suderman 
1830eb04f321STatWai Chong   if (!ShapedType::isDynamic(inputDepth) &&
1831eb04f321STatWai Chong       !ShapedType::isDynamic(weightDepth)) {
1832eb04f321STatWai Chong     int64_t inputSize = inputDepth + pad[0] + pad[1];
1833eb04f321STatWai Chong     int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
183411dda1a2SRob Suderman     int64_t unstridedResult = inputSize - filterSize + 1;
183511dda1a2SRob Suderman     outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
183611dda1a2SRob Suderman   }
183711dda1a2SRob Suderman 
1838eb04f321STatWai Chong   if (!ShapedType::isDynamic(inputHeight) &&
1839eb04f321STatWai Chong       !ShapedType::isDynamic(weightHeight)) {
1840eb04f321STatWai Chong     int64_t inputSize = inputHeight + pad[2] + pad[3];
1841eb04f321STatWai Chong     int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
184211dda1a2SRob Suderman     int64_t unstridedResult = inputSize - filterSize + 1;
184311dda1a2SRob Suderman     outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
184411dda1a2SRob Suderman   }
184511dda1a2SRob Suderman 
1846eb04f321STatWai Chong   if (!ShapedType::isDynamic(inputWidth) &&
1847eb04f321STatWai Chong       !ShapedType::isDynamic(weightWidth)) {
1848eb04f321STatWai Chong     int64_t inputSize = inputWidth + pad[4] + pad[5];
1849eb04f321STatWai Chong     int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
185011dda1a2SRob Suderman     int64_t unstridedResult = inputSize - filterSize + 1;
185111dda1a2SRob Suderman     outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
185211dda1a2SRob Suderman   }
185311dda1a2SRob Suderman 
185411dda1a2SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
185511dda1a2SRob Suderman   return success();
185611dda1a2SRob Suderman }
185711dda1a2SRob Suderman 
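// Note: Conv3D uses the NDHWC layout, so pad carries three begin/end pairs
// ([d0, d1, h0, h1, w0, w1]) and dilation/stride carry one entry per spatial
// axis. The per-axis arithmetic mirrors the 2D case; e.g. with hypothetical
// values inputDepth = 8, pad[0] = pad[1] = 0, weightDepth = 3,
// dilation[0] = 1, and stride[0] = 1, outputShape[1] = 8 - 3 + 1 = 6.
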
1858360a03c9SJack Frankland LogicalResult Conv3DOp::verify() {
1859360a03c9SJack Frankland   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1860360a03c9SJack Frankland     return failure();
1861360a03c9SJack Frankland   return success();
1862360a03c9SJack Frankland }
18631be88f5aSRiver Riddle 
1864f2832c22SRob Suderman LogicalResult AvgPool2dOp::inferReturnTypeComponents(
186522426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1866057fc8e7SAmanda Tang     AvgPool2dOp::Adaptor adaptor,
1867f2832c22SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1868057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput().getType());
1869057fc8e7SAmanda Tang   const Properties &prop = adaptor.getProperties();
1870057fc8e7SAmanda Tang   return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1871057fc8e7SAmanda Tang                                  inferredReturnShapes);
1872f2832c22SRob Suderman }
1873f2832c22SRob Suderman 
1874f2832c22SRob Suderman LogicalResult MaxPool2dOp::inferReturnTypeComponents(
187522426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1876057fc8e7SAmanda Tang     MaxPool2dOp::Adaptor adaptor,
1877f2832c22SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1878057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput().getType());
1879057fc8e7SAmanda Tang   const Properties &prop = adaptor.getProperties();
1880057fc8e7SAmanda Tang   return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1881057fc8e7SAmanda Tang                                  inferredReturnShapes);
1882f2832c22SRob Suderman }
1883f2832c22SRob Suderman 
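// Both pooling ops above delegate to the shared poolingInferReturnTypes
// helper, which applies the same kernel/stride/pad arithmetic as the
// convolutions (with no dilation). A hypothetical IR sketch of the result:
//   %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 2>,
//                               stride = array<i64: 2, 2>,
//                               pad = array<i64: 0, 0, 0, 0>}
//        : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x8xf32>
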
188411dda1a2SRob Suderman LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
188522426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1886057fc8e7SAmanda Tang     DepthwiseConv2DOp::Adaptor adaptor,
188711dda1a2SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1888399638f9SAliia Khasanova   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
188911dda1a2SRob Suderman 
1890399638f9SAliia Khasanova   int64_t inputWidth = ShapedType::kDynamic;
1891399638f9SAliia Khasanova   int64_t inputHeight = ShapedType::kDynamic;
1892399638f9SAliia Khasanova   int64_t inputChannels = ShapedType::kDynamic;
189311dda1a2SRob Suderman 
1894399638f9SAliia Khasanova   int64_t weightWidth = ShapedType::kDynamic;
1895399638f9SAliia Khasanova   int64_t weightHeight = ShapedType::kDynamic;
1896399638f9SAliia Khasanova   int64_t depthChannels = ShapedType::kDynamic;
189711dda1a2SRob Suderman 
189811dda1a2SRob Suderman   // Input shape describes input width/height, batch, and input channels.
1899057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput().getType());
190009349303SJacques Pienaar   if (inputShape.hasRank()) {
190109349303SJacques Pienaar     outputShape[0] = inputShape.getDimSize(0);
190209349303SJacques Pienaar     inputHeight = inputShape.getDimSize(1);
190309349303SJacques Pienaar     inputWidth = inputShape.getDimSize(2);
190409349303SJacques Pienaar     inputChannels = inputShape.getDimSize(3);
190511dda1a2SRob Suderman   }
190611dda1a2SRob Suderman 
190711dda1a2SRob Suderman   // Weight shape describes the filter width/height, the input channels, and
190711dda1a2SRob Suderman   // the depth multiplier.
1908057fc8e7SAmanda Tang   ShapeAdaptor weightShape(adaptor.getWeight().getType());
190909349303SJacques Pienaar   if (weightShape.hasRank()) {
191009349303SJacques Pienaar     weightHeight = weightShape.getDimSize(0);
191109349303SJacques Pienaar     weightWidth = weightShape.getDimSize(1);
191211dda1a2SRob Suderman     inputChannels = ShapedType::isDynamic(inputChannels)
191309349303SJacques Pienaar                         ? weightShape.getDimSize(2)
191411dda1a2SRob Suderman                         : inputChannels;
191509349303SJacques Pienaar     depthChannels = weightShape.getDimSize(3);
191611dda1a2SRob Suderman   }
191711dda1a2SRob Suderman 
191811dda1a2SRob Suderman   // If both inputChannels and depthChannels are available we can determine
191911dda1a2SRob Suderman   // the output channels.
192011dda1a2SRob Suderman   if (!ShapedType::isDynamic(inputChannels) &&
192111dda1a2SRob Suderman       !ShapedType::isDynamic(depthChannels)) {
192211dda1a2SRob Suderman     outputShape[3] = inputChannels * depthChannels;
192311dda1a2SRob Suderman   }
192411dda1a2SRob Suderman 
192511dda1a2SRob Suderman   // Bias shape can describe the output channels.
1926057fc8e7SAmanda Tang   ShapeAdaptor biasShape(adaptor.getBias().getType());
192709349303SJacques Pienaar   if (biasShape.hasRank()) {
192811dda1a2SRob Suderman     outputShape[3] = ShapedType::isDynamic(outputShape[3])
192909349303SJacques Pienaar                          ? biasShape.getDimSize(0)
193011dda1a2SRob Suderman                          : outputShape[3];
193111dda1a2SRob Suderman   }
193211dda1a2SRob Suderman 
193311030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
193411030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> padding = adaptor.getPad();
193511030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
193611dda1a2SRob Suderman 
193711dda1a2SRob Suderman   if (!ShapedType::isDynamic(inputHeight) &&
193811dda1a2SRob Suderman       !ShapedType::isDynamic(weightHeight)) {
1939fb4cedccSAliia Khasanova     int64_t inputSize = inputHeight + padding[0] + padding[1];
1940fb4cedccSAliia Khasanova     int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1941fb4cedccSAliia Khasanova     int64_t unstridedResult = inputSize - filterSize + 1;
194211dda1a2SRob Suderman     outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
194311dda1a2SRob Suderman   }
194411dda1a2SRob Suderman 
194511dda1a2SRob Suderman   if (!ShapedType::isDynamic(inputWidth) &&
194611dda1a2SRob Suderman       !ShapedType::isDynamic(weightWidth)) {
1947fb4cedccSAliia Khasanova     int64_t inputSize = inputWidth + padding[2] + padding[3];
1948fb4cedccSAliia Khasanova     int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1949fb4cedccSAliia Khasanova     int64_t unstridedResult = inputSize - filterSize + 1;
195011dda1a2SRob Suderman     outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
195111dda1a2SRob Suderman   }
195211dda1a2SRob Suderman 
195311dda1a2SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
195411dda1a2SRob Suderman   return success();
195511dda1a2SRob Suderman }
195611dda1a2SRob Suderman 
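// For depthwise convolution the weight is laid out [H, W, C, M], so the
// inferred output channel count is C * M. With hypothetical shapes, a
// tensor<1x16x16x8xf32> input and a tensor<3x3x8x2xf32> weight give
// outputShape[3] = 8 * 2 = 16.
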
1957360a03c9SJack Frankland LogicalResult DepthwiseConv2DOp::verify() {
1958360a03c9SJack Frankland   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1959360a03c9SJack Frankland     return failure();
1960360a03c9SJack Frankland   return success();
1961360a03c9SJack Frankland }
19621be88f5aSRiver Riddle 
196311dda1a2SRob Suderman LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
196422426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
1965057fc8e7SAmanda Tang     TransposeConv2DOp::Adaptor adaptor,
196611dda1a2SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
196711030c7dSAlexander Shaposhnikov   // Start from the declared out_shape; dynamic entries are refined below.
196811030c7dSAlexander Shaposhnikov   llvm::SmallVector<int64_t> outputShape =
196911030c7dSAlexander Shaposhnikov       convertToMlirShape(adaptor.getOutShape());
197011dda1a2SRob Suderman 
1971399638f9SAliia Khasanova   int64_t inputWidth = ShapedType::kDynamic;
1972399638f9SAliia Khasanova   int64_t inputHeight = ShapedType::kDynamic;
1973399638f9SAliia Khasanova   int64_t weightWidth = ShapedType::kDynamic;
1974399638f9SAliia Khasanova   int64_t weightHeight = ShapedType::kDynamic;
197511dda1a2SRob Suderman 
197611dda1a2SRob Suderman   // Input shape describes input width/height and batch.
1977057fc8e7SAmanda Tang   ShapeAdaptor inputShape(adaptor.getInput().getType());
197809349303SJacques Pienaar   if (inputShape.hasRank()) {
197911dda1a2SRob Suderman     outputShape[0] = ShapedType::isDynamic(outputShape[0])
198009349303SJacques Pienaar                          ? inputShape.getDimSize(0)
198111dda1a2SRob Suderman                          : outputShape[0];
198209349303SJacques Pienaar     inputHeight = inputShape.getDimSize(1);
198309349303SJacques Pienaar     inputWidth = inputShape.getDimSize(2);
198411dda1a2SRob Suderman   }
198511dda1a2SRob Suderman 
198611dda1a2SRob Suderman   // Weight shape describes the filter width/height and the output channels.
1987057fc8e7SAmanda Tang   ShapeAdaptor weightShape(adaptor.getFilter().getType());
198809349303SJacques Pienaar   if (weightShape.hasRank()) {
198911dda1a2SRob Suderman     outputShape[3] = ShapedType::isDynamic(outputShape[3])
199009349303SJacques Pienaar                          ? weightShape.getDimSize(0)
199111dda1a2SRob Suderman                          : outputShape[3];
199209349303SJacques Pienaar     weightHeight = weightShape.getDimSize(1);
199309349303SJacques Pienaar     weightWidth = weightShape.getDimSize(2);
199411dda1a2SRob Suderman   }
199511dda1a2SRob Suderman 
199611dda1a2SRob Suderman   // Bias shape can describe the output channels.
1997057fc8e7SAmanda Tang   ShapeAdaptor biasShape(adaptor.getBias().getType());
199809349303SJacques Pienaar   if (biasShape.hasRank()) {
199911dda1a2SRob Suderman     outputShape[3] = ShapedType::isDynamic(outputShape[3])
200009349303SJacques Pienaar                          ? biasShape.getDimSize(0)
200111dda1a2SRob Suderman                          : outputShape[3];
200211dda1a2SRob Suderman   }
200311dda1a2SRob Suderman 
200411030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
200511030c7dSAlexander Shaposhnikov   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
200611dda1a2SRob Suderman 
200711dda1a2SRob Suderman   if (!ShapedType::isDynamic(inputHeight) &&
200811dda1a2SRob Suderman       !ShapedType::isDynamic(weightHeight)) {
2009fb4cedccSAliia Khasanova     int64_t calculateSize =
2010fcbf3fafSRob Suderman         (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
2011fb4cedccSAliia Khasanova     outputShape[1] =
2012fb4cedccSAliia Khasanova         ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
201311dda1a2SRob Suderman   }
201411dda1a2SRob Suderman 
201511dda1a2SRob Suderman   if (!ShapedType::isDynamic(inputWidth) &&
201611dda1a2SRob Suderman       !ShapedType::isDynamic(weightWidth)) {
2017fb4cedccSAliia Khasanova     int64_t calculateSize =
2018fcbf3fafSRob Suderman         (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
2019fb4cedccSAliia Khasanova     outputShape[2] =
2020fb4cedccSAliia Khasanova         ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
202111dda1a2SRob Suderman   }
202211dda1a2SRob Suderman 
202311dda1a2SRob Suderman   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
202411dda1a2SRob Suderman   return success();
202511dda1a2SRob Suderman }
202611dda1a2SRob Suderman 
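// Transpose convolution grows the spatial dimensions rather than shrinking
// them: out = (in - 1) * stride + out_pad_begin + out_pad_end + kernel. A
// hypothetical example: inputHeight = 8, stride[0] = 2, zero out_pad, and
// weightHeight = 3 give outputShape[1] = (8 - 1) * 2 + 0 + 0 + 3 = 17.
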
2027360a03c9SJack Frankland LogicalResult TransposeConv2DOp::verify() {
2028360a03c9SJack Frankland   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2029360a03c9SJack Frankland     return failure();
2030360a03c9SJack Frankland   return success();
2031360a03c9SJack Frankland }
2032360a03c9SJack Frankland 
20331b00b94fSRob Suderman LogicalResult IfOp::inferReturnTypeComponents(
203422426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
2035057fc8e7SAmanda Tang     IfOp::Adaptor adaptor,
20361b00b94fSRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
20371b00b94fSRob Suderman   llvm::SmallVector<tosa::YieldOp> yieldOps;
2038057fc8e7SAmanda Tang   for (Region *region : adaptor.getRegions()) {
20391b00b94fSRob Suderman     for (auto &block : *region)
20401b00b94fSRob Suderman       if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
20411b00b94fSRob Suderman         yieldOps.push_back(returnOp);
20421b00b94fSRob Suderman   }
20431b00b94fSRob Suderman 
20441b00b94fSRob Suderman   if (yieldOps.empty())
20451b00b94fSRob Suderman     return failure();
20461b00b94fSRob Suderman 
20471b00b94fSRob Suderman   // Get the initial type information from the first yield op.
20481b00b94fSRob Suderman   llvm::SmallVector<ValueKnowledge> resultKnowledge;
20491b00b94fSRob Suderman   resultKnowledge.reserve(yieldOps.front().getNumOperands());
20501b00b94fSRob Suderman   for (auto operand : yieldOps.front().getOperands()) {
20511b00b94fSRob Suderman     resultKnowledge.push_back(
20521b00b94fSRob Suderman         ValueKnowledge::getKnowledgeFromType(operand.getType()));
20531b00b94fSRob Suderman   }
20541b00b94fSRob Suderman 
20551b00b94fSRob Suderman   for (auto yieldOp : yieldOps) {
20561b00b94fSRob Suderman     if (resultKnowledge.size() != yieldOp.getNumOperands())
20571b00b94fSRob Suderman       return failure();
20581b00b94fSRob Suderman 
205989de9cc8SMehdi Amini     for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
20601b00b94fSRob Suderman       int32_t index = it.index();
20611b00b94fSRob Suderman       auto meet = ValueKnowledge::meet(
20621b00b94fSRob Suderman           resultKnowledge[index],
20631b00b94fSRob Suderman           ValueKnowledge::getKnowledgeFromType(it.value().getType()));
20641b00b94fSRob Suderman       if (!meet)
20651b00b94fSRob Suderman         continue;
20661b00b94fSRob Suderman       resultKnowledge[index] = meet;
20671b00b94fSRob Suderman     }
20681b00b94fSRob Suderman   }
20691b00b94fSRob Suderman 
207009349303SJacques Pienaar   for (const ValueKnowledge &result : resultKnowledge) {
2071b0532286SRob Suderman     inferredReturnShapes.push_back(result.getShapedTypeComponents());
20721b00b94fSRob Suderman   }
2073b0532286SRob Suderman 
2074b0532286SRob Suderman   return success();
2075b0532286SRob Suderman }
2076b0532286SRob Suderman 
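// Both cond_if and while_loop inference rely on ValueKnowledge::meet, which
// keeps only the facts consistent with every yield. As a hypothetical
// example, meeting tensor<1x?xf32> with tensor<?x4xf32> refines the result
// to tensor<1x4xf32>; when the operands cannot be reconciled, meet fails
// and the knowledge gathered so far is kept.
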
2077b0532286SRob Suderman LogicalResult WhileOp::inferReturnTypeComponents(
207822426110SRamkumar Ramachandra     MLIRContext *context, ::std::optional<Location> location,
2079057fc8e7SAmanda Tang     WhileOp::Adaptor adaptor,
2080b0532286SRob Suderman     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2081b0532286SRob Suderman   llvm::SmallVector<tosa::YieldOp> yieldOps;
2082057fc8e7SAmanda Tang   for (auto &block : adaptor.getBody())
2083b0532286SRob Suderman     if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
2084b0532286SRob Suderman       yieldOps.push_back(returnOp);
2085b0532286SRob Suderman 
2086b0532286SRob Suderman   // TOSA's while must have a tosa.yield as its terminator. If none is found,
2087b0532286SRob Suderman   // this tosa.while is invalid.
2088b0532286SRob Suderman   if (yieldOps.empty())
2089b0532286SRob Suderman     return failure();
2090b0532286SRob Suderman 
2091b0532286SRob Suderman   // Get the initial type information from the first yield's operand types.
2092b0532286SRob Suderman   llvm::SmallVector<ValueKnowledge> resultKnowledge;
2093b0532286SRob Suderman   resultKnowledge.reserve(yieldOps.front().getNumOperands());
2094b0532286SRob Suderman   for (auto operand : yieldOps.front().getOperands()) {
2095b0532286SRob Suderman     resultKnowledge.push_back(
2096b0532286SRob Suderman         ValueKnowledge::getKnowledgeFromType(operand.getType()));
2097b0532286SRob Suderman   }
2098b0532286SRob Suderman 
2099b0532286SRob Suderman   for (auto yieldOp : yieldOps) {
2100b0532286SRob Suderman     if (resultKnowledge.size() != yieldOp.getNumOperands())
2101b0532286SRob Suderman       return failure();
2102b0532286SRob Suderman 
210389de9cc8SMehdi Amini     for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
2104b0532286SRob Suderman       int32_t index = it.index();
2105b0532286SRob Suderman       if (auto meet = ValueKnowledge::meet(
2106b0532286SRob Suderman               resultKnowledge[index],
2107b0532286SRob Suderman               ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
2108b0532286SRob Suderman         resultKnowledge[index] = meet;
21095161835dSlipracer       }
2110b0532286SRob Suderman     }
2111b0532286SRob Suderman   }
2112b0532286SRob Suderman 
2113b0532286SRob Suderman   for (const ValueKnowledge &result : resultKnowledge) {
2114b0532286SRob Suderman     inferredReturnShapes.push_back(result.getShapedTypeComponents());
21151b00b94fSRob Suderman   }
21161b00b94fSRob Suderman 
21171b00b94fSRob Suderman   return success();
21181b00b94fSRob Suderman }
21191b00b94fSRob Suderman 
212081566001SJakub Kuderski std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
2121c1fa60b4STres Popp   if (auto vt = llvm::dyn_cast<VectorType>(getType()))
212281566001SJakub Kuderski     return llvm::to_vector<4>(vt.getShape());
212381566001SJakub Kuderski   return std::nullopt;
212481566001SJakub Kuderski }
212581566001SJakub Kuderski 
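// For example (hypothetical types), an apply_scale producing vector<4x8xi32>
// reports an unroll shape of {4, 8}, while a scalar result returns
// std::nullopt and is skipped by vector unrolling.
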
2126da4d191fSTatWai Chong // The parse and print implementations of IfOp follow those of the SCF dialect.
2127da4d191fSTatWai Chong ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2128da4d191fSTatWai Chong   // Create the regions for 'then' and 'else'.
2129da4d191fSTatWai Chong   result.regions.reserve(2);
2130da4d191fSTatWai Chong   Region *thenRegion = result.addRegion();
2131da4d191fSTatWai Chong   Region *elseRegion = result.addRegion();
2132da4d191fSTatWai Chong 
2133da4d191fSTatWai Chong   auto &builder = parser.getBuilder();
2134da4d191fSTatWai Chong   OpAsmParser::UnresolvedOperand cond;
2135da4d191fSTatWai Chong   // Create an i1 tensor type for the boolean condition.
2136da4d191fSTatWai Chong   Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
2137da4d191fSTatWai Chong   if (parser.parseOperand(cond) ||
2138da4d191fSTatWai Chong       parser.resolveOperand(cond, i1Type, result.operands))
2139da4d191fSTatWai Chong     return failure();
2140da4d191fSTatWai Chong   // Parse optional results type list.
2141da4d191fSTatWai Chong   if (parser.parseOptionalArrowTypeList(result.types))
2142da4d191fSTatWai Chong     return failure();
2143da4d191fSTatWai Chong   // Parse the 'then' region.
2144da4d191fSTatWai Chong   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2145da4d191fSTatWai Chong     return failure();
2146da4d191fSTatWai Chong 
2147da4d191fSTatWai Chong   // If we find an 'else' keyword then parse the 'else' region.
2148da4d191fSTatWai Chong   if (!parser.parseOptionalKeyword("else")) {
2149da4d191fSTatWai Chong     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2150da4d191fSTatWai Chong       return failure();
2151da4d191fSTatWai Chong   }
2152da4d191fSTatWai Chong 
2153da4d191fSTatWai Chong   // Parse the optional attribute list.
2154da4d191fSTatWai Chong   if (parser.parseOptionalAttrDict(result.attributes))
2155da4d191fSTatWai Chong     return failure();
2156da4d191fSTatWai Chong   return success();
2157da4d191fSTatWai Chong }
2158da4d191fSTatWai Chong 
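// A hypothetical example of the custom assembly format accepted above:
//   tosa.cond_if %cond -> (tensor<f32>) {
//     tosa.yield %a : tensor<f32>
//   } else {
//     tosa.yield %b : tensor<f32>
//   }
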
2159da4d191fSTatWai Chong void IfOp::print(OpAsmPrinter &p) {
2160da4d191fSTatWai Chong   bool printBlockTerminators = false;
2161da4d191fSTatWai Chong 
2162da4d191fSTatWai Chong   p << " " << getCond();
2163da4d191fSTatWai Chong   if (!getResults().empty()) {
2164da4d191fSTatWai Chong     p << " -> (" << getResultTypes() << ")";
2165da4d191fSTatWai Chong     // Print yield explicitly if the op defines values.
2166da4d191fSTatWai Chong     printBlockTerminators = true;
2167da4d191fSTatWai Chong   }
2168da4d191fSTatWai Chong   p << ' ';
2169da4d191fSTatWai Chong   p.printRegion(getThenBranch(),
2170da4d191fSTatWai Chong                 /*printEntryBlockArgs=*/false,
2171da4d191fSTatWai Chong                 /*printBlockTerminators=*/printBlockTerminators);
2172da4d191fSTatWai Chong 
2173da4d191fSTatWai Chong   // Print the 'else' region if it exists and has a block.
2174da4d191fSTatWai Chong   auto &elseRegion = getElseBranch();
2175da4d191fSTatWai Chong   if (!elseRegion.empty()) {
2176da4d191fSTatWai Chong     p << " else ";
2177da4d191fSTatWai Chong     p.printRegion(elseRegion,
2178da4d191fSTatWai Chong                   /*printEntryBlockArgs=*/false,
2179da4d191fSTatWai Chong                   /*printBlockTerminators=*/printBlockTerminators);
2180da4d191fSTatWai Chong   }
2181da4d191fSTatWai Chong 
2182da4d191fSTatWai Chong   p.printOptionalAttrDict((*this)->getAttrs());
2183da4d191fSTatWai Chong }
2184da4d191fSTatWai Chong 
21856ed2d30dSFelix Schneider LogicalResult ReverseOp::verify() {
2186c6876b4eSJerry-Ge   TensorType inputType = getInput1().getType();
21876ed2d30dSFelix Schneider   TensorType outputType = getOutput().getType();
21886ed2d30dSFelix Schneider   int32_t reverseAxis = getAxis();
21896ed2d30dSFelix Schneider 
21906ed2d30dSFelix Schneider   if (reverseAxis < 0)
21916ed2d30dSFelix Schneider     return emitOpError("expected non-negative reverse axis");
21926ed2d30dSFelix Schneider   if (inputType.hasRank()) {
21936ed2d30dSFelix Schneider     int64_t inputRank = inputType.getRank();
21946ed2d30dSFelix Schneider     // We allow for a special case where the input/output shape has rank 0 and
21956ed2d30dSFelix Schneider     // axis is also 0.
21966ed2d30dSFelix Schneider     if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
21976ed2d30dSFelix Schneider       return emitOpError("expect input tensor rank (")
21986ed2d30dSFelix Schneider              << inputRank << ") to be larger than reverse axis (" << reverseAxis
21996ed2d30dSFelix Schneider              << ")";
22006ed2d30dSFelix Schneider   }
22016ed2d30dSFelix Schneider   if (outputType.hasRank()) {
22026ed2d30dSFelix Schneider     int64_t outputRank = outputType.getRank();
22036ed2d30dSFelix Schneider     if (inputType.hasRank() && outputRank != inputType.getRank())
22046ed2d30dSFelix Schneider       return emitOpError(
22056ed2d30dSFelix Schneider           "expect output tensor rank to be equal to input tensor rank");
22066ed2d30dSFelix Schneider     if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
22076ed2d30dSFelix Schneider       return emitOpError("expect output tensor rank (")
22086ed2d30dSFelix Schneider              << outputRank << ") to be larger than reverse axis ("
22096ed2d30dSFelix Schneider              << reverseAxis << ")";
22106ed2d30dSFelix Schneider   }
22116ed2d30dSFelix Schneider   return success();
22126ed2d30dSFelix Schneider }
22136ed2d30dSFelix Schneider 
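// As a hypothetical example, the checks above reject
//   %0 = tosa.reverse %arg0 {axis = 3 : i32}
//        : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
// because the axis (3) is not smaller than the input rank (3), while
// axis = 0 on a rank-0 tensor is explicitly permitted.
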
2214da4d191fSTatWai Chong // The parse and print implementations of WhileOp follow those of the SCF dialect.
2215da4d191fSTatWai Chong ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2216da4d191fSTatWai Chong   SmallVector<OpAsmParser::Argument, 4> regionArgs;
2217da4d191fSTatWai Chong   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2218da4d191fSTatWai Chong   Region *cond = result.addRegion();
2219da4d191fSTatWai Chong   Region *body = result.addRegion();
2220da4d191fSTatWai Chong 
2221da4d191fSTatWai Chong   OptionalParseResult listResult =
2222da4d191fSTatWai Chong       parser.parseOptionalAssignmentList(regionArgs, operands);
2223da4d191fSTatWai Chong   if (listResult.has_value() && failed(listResult.value()))
2224da4d191fSTatWai Chong     return failure();
2225da4d191fSTatWai Chong 
2226da4d191fSTatWai Chong   FunctionType functionType;
2227da4d191fSTatWai Chong   SMLoc typeLoc = parser.getCurrentLocation();
2228da4d191fSTatWai Chong   if (failed(parser.parseColonType(functionType)))
2229da4d191fSTatWai Chong     return failure();
2230da4d191fSTatWai Chong 
2231da4d191fSTatWai Chong   result.addTypes(functionType.getResults());
2232da4d191fSTatWai Chong 
2233da4d191fSTatWai Chong   if (functionType.getNumInputs() != operands.size()) {
2234da4d191fSTatWai Chong     return parser.emitError(typeLoc)
2235da4d191fSTatWai Chong            << "expected as many input types as operands "
2236da4d191fSTatWai Chong            << "(expected " << operands.size() << " got "
2237da4d191fSTatWai Chong            << functionType.getNumInputs() << ")";
2238da4d191fSTatWai Chong   }
2239da4d191fSTatWai Chong 
2240da4d191fSTatWai Chong   // Resolve input operands.
2241da4d191fSTatWai Chong   if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2242da4d191fSTatWai Chong                                     parser.getCurrentLocation(),
2243da4d191fSTatWai Chong                                     result.operands)))
2244da4d191fSTatWai Chong     return failure();
2245da4d191fSTatWai Chong 
2246da4d191fSTatWai Chong   // Propagate the types into the region arguments.
2247da4d191fSTatWai Chong   for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2248da4d191fSTatWai Chong     regionArgs[i].type = functionType.getInput(i);
2249da4d191fSTatWai Chong 
2250da4d191fSTatWai Chong   return failure(parser.parseRegion(*cond, regionArgs) ||
2251da4d191fSTatWai Chong                  parser.parseKeyword("do") || parser.parseRegion(*body) ||
2252da4d191fSTatWai Chong                  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2253da4d191fSTatWai Chong }
2254da4d191fSTatWai Chong 
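// A hypothetical example of the textual form this parser accepts:
//   %res = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ... condition region, yielding an i1 tensor ...
//   } do {
//     ... body region ...
//   }
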
2255da4d191fSTatWai Chong static void printInitializationList(OpAsmPrinter &printer,
2256da4d191fSTatWai Chong                                     Block::BlockArgListType blocksArgs,
2257da4d191fSTatWai Chong                                     ValueRange initializers,
2258da4d191fSTatWai Chong                                     StringRef prefix = "") {
2259da4d191fSTatWai Chong   assert(blocksArgs.size() == initializers.size() &&
2260da4d191fSTatWai Chong          "expected same length of arguments and initializers");
2261da4d191fSTatWai Chong   if (initializers.empty())
2262da4d191fSTatWai Chong     return;
2263da4d191fSTatWai Chong 
2264da4d191fSTatWai Chong   printer << prefix << '(';
2265da4d191fSTatWai Chong   llvm::interleaveComma(
2266da4d191fSTatWai Chong       llvm::zip(blocksArgs, initializers), printer,
2267da4d191fSTatWai Chong       [&](auto it) { printer << std::get<0>(it) << " = " << std::get<1>(it); });
2268da4d191fSTatWai Chong   printer << ")";
2269da4d191fSTatWai Chong }
2270da4d191fSTatWai Chong 
2271da4d191fSTatWai Chong void WhileOp::print(OpAsmPrinter &printer) {
2272da4d191fSTatWai Chong   printInitializationList(printer, getCond().front().getArguments(),
2273da4d191fSTatWai Chong                           getInputs(), " ");
2274da4d191fSTatWai Chong   printer << " : ";
2275da4d191fSTatWai Chong   printer.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
2276da4d191fSTatWai Chong   printer << ' ';
2277da4d191fSTatWai Chong   printer.printRegion(getCond(), /*printEntryBlockArgs=*/false);
2278da4d191fSTatWai Chong   printer << " do ";
2279da4d191fSTatWai Chong   printer.printRegion(getBody());
2280da4d191fSTatWai Chong   printer.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2281da4d191fSTatWai Chong }
2282da4d191fSTatWai Chong 
22838dea784bSRob Suderman //===----------------------------------------------------------------------===//
2284f09db6a3SJerry-Ge // TOSA Shape and Shape Operators Helper Functions.
2285f09db6a3SJerry-Ge //===----------------------------------------------------------------------===//
2286f09db6a3SJerry-Ge 
2287f09db6a3SJerry-Ge bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
2288f09db6a3SJerry-Ge   return mlir::isa<tosa::shapeType>(t);
2289f09db6a3SJerry-Ge }
2290f09db6a3SJerry-Ge 
2291f09db6a3SJerry-Ge LogicalResult
2292f09db6a3SJerry-Ge mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
2293f09db6a3SJerry-Ge                               int rank) {
2294f09db6a3SJerry-Ge   if (rank < 0)
2295f09db6a3SJerry-Ge     return emitError() << "invalid rank (must be >= 0): " << rank;
2296f09db6a3SJerry-Ge   return success();
2297f09db6a3SJerry-Ge }
2298f09db6a3SJerry-Ge 
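// For example, !tosa.shape<4> denotes a rank-4 shape value and
// !tosa.shape<0> a rank-0 one; a negative rank such as !tosa.shape<-1> is
// rejected by the verifier above.
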
2299f09db6a3SJerry-Ge LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
2300f09db6a3SJerry-Ge   for (auto v : op->getOperands()) {
2301f09db6a3SJerry-Ge     if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
2302f09db6a3SJerry-Ge       Operation *definingOp = v.getDefiningOp();
2303f09db6a3SJerry-Ge       if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
2304f09db6a3SJerry-Ge         return op->emitOpError("shape operand is not compile time resolvable");
2305f09db6a3SJerry-Ge       }
2306f09db6a3SJerry-Ge     }
2307f09db6a3SJerry-Ge   }
2308f09db6a3SJerry-Ge   return success();
2309f09db6a3SJerry-Ge }
2310f09db6a3SJerry-Ge 
2311f09db6a3SJerry-Ge LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
2312f09db6a3SJerry-Ge   for (auto type : op->getOperandTypes()) {
2313f09db6a3SJerry-Ge     if (!mlir::isa<mlir::tosa::shapeType>(type)) {
2314f09db6a3SJerry-Ge       return op->emitOpError("must have operands with tosa shape type");
2315f09db6a3SJerry-Ge     }
2316f09db6a3SJerry-Ge   }
2317f09db6a3SJerry-Ge   for (auto type : op->getResultTypes()) {
2318f09db6a3SJerry-Ge     if (!mlir::isa<mlir::tosa::shapeType>(type)) {
2319f09db6a3SJerry-Ge       return op->emitOpError("must have result with tosa shape type");
2320f09db6a3SJerry-Ge     }
2321f09db6a3SJerry-Ge   }
2322f09db6a3SJerry-Ge   return success();
2323f09db6a3SJerry-Ge }
2324f09db6a3SJerry-Ge 
2325f09db6a3SJerry-Ge LogicalResult
2326f09db6a3SJerry-Ge OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
2327f09db6a3SJerry-Ge   if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
2328f09db6a3SJerry-Ge       failed(verifyTosaShapeOperator(op)))
2329f09db6a3SJerry-Ge     return failure();
2330f09db6a3SJerry-Ge 
2331f09db6a3SJerry-Ge   // Delegate that returns the rank of a tosa shape type.
2332f09db6a3SJerry-Ge   auto getRank = [](const Type type) {
2333f09db6a3SJerry-Ge     return mlir::cast<mlir::tosa::shapeType>(type).getRank();
2334f09db6a3SJerry-Ge   };
2335f09db6a3SJerry-Ge   auto operandTypes = op->getOperandTypes();
2336f09db6a3SJerry-Ge   auto resultTypes = op->getResultTypes();
2337f09db6a3SJerry-Ge 
2338f09db6a3SJerry-Ge   auto rank = getRank(*operandTypes.begin());
2339f09db6a3SJerry-Ge   for (auto type : operandTypes) {
2340f09db6a3SJerry-Ge     if (getRank(type) != rank) {
2341f09db6a3SJerry-Ge       return op->emitOpError("operands don't have matching ranks");
2342f09db6a3SJerry-Ge     }
2343f09db6a3SJerry-Ge   }
2344f09db6a3SJerry-Ge   for (auto type : resultTypes) {
2345f09db6a3SJerry-Ge     if (getRank(type) != rank) {
2346f09db6a3SJerry-Ge       return op->emitOpError("result shape has different rank than operands");
2347f09db6a3SJerry-Ge     }
2348f09db6a3SJerry-Ge   }
2349f09db6a3SJerry-Ge   return success();
2350f09db6a3SJerry-Ge }
2351f09db6a3SJerry-Ge 
2352f09db6a3SJerry-Ge //===----------------------------------------------------------------------===//
2353f09db6a3SJerry-Ge // TOSA Shape Operators Verify Functions.
2354f09db6a3SJerry-Ge //===----------------------------------------------------------------------===//
2355f09db6a3SJerry-Ge 
2356f09db6a3SJerry-Ge LogicalResult tosa::ConstShapeOp::verify() {
2357f09db6a3SJerry-Ge   // Check that the value attr's element count equals the result shape's rank.
2358f09db6a3SJerry-Ge   auto count = getValue().getNumElements();
2359f09db6a3SJerry-Ge   auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
2360f09db6a3SJerry-Ge   if (!(count == rank || (count == 1 && rank == 0))) {
2361f09db6a3SJerry-Ge     return emitOpError("expect number of elements in attribute value (")
2362f09db6a3SJerry-Ge            << count << ") to be equal to the rank (" << rank
2363f09db6a3SJerry-Ge            << ") for the result shape type";
2364f09db6a3SJerry-Ge   }
2365f09db6a3SJerry-Ge   return success();
2366f09db6a3SJerry-Ge }
2367f09db6a3SJerry-Ge 
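// A hypothetical example that satisfies this verifier:
//   %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>}
//        : !tosa.shape<4>
// Four elements match the rank-4 result type; a single-element value with a
// rank-0 result is the one permitted special case.
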
2368f09db6a3SJerry-Ge //===----------------------------------------------------------------------===//
2369f1182bd6SMogball // TOSA Attribute Definitions.
2370f1182bd6SMogball //===----------------------------------------------------------------------===//
2371f1182bd6SMogball 
2372f1182bd6SMogball #define GET_ATTRDEF_CLASSES
2373f1182bd6SMogball #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
2374f1182bd6SMogball 
2375f1182bd6SMogball //===----------------------------------------------------------------------===//
2376f09db6a3SJerry-Ge // TOSA Type Definitions.
2377f09db6a3SJerry-Ge //===----------------------------------------------------------------------===//
2378f09db6a3SJerry-Ge #define GET_TYPEDEF_CLASSES
2379f09db6a3SJerry-Ge #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
2380f09db6a3SJerry-Ge 
2381f09db6a3SJerry-Ge //===----------------------------------------------------------------------===//
2382b2812113SSuraj Sudhir // TOSA Operator Definitions.
2383b2812113SSuraj Sudhir //===----------------------------------------------------------------------===//
2384b2812113SSuraj Sudhir 
2385b2812113SSuraj Sudhir #define GET_OP_CLASSES
2386b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
2387