1b2812113SSuraj Sudhir //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// 2b2812113SSuraj Sudhir // 3b2812113SSuraj Sudhir // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b2812113SSuraj Sudhir // See https://llvm.org/LICENSE.txt for license information. 5b2812113SSuraj Sudhir // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b2812113SSuraj Sudhir // 7b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 8b2812113SSuraj Sudhir // 9b2812113SSuraj Sudhir // \file 10b2812113SSuraj Sudhir // This file implements the TOSA Specification: 11b2812113SSuraj Sudhir // https://developer.mlplatform.org/w/tosa/ 12b2812113SSuraj Sudhir // 13b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 14b2812113SSuraj Sudhir 15b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/IR/TosaOps.h" 16513cdb82SJustin Fargnoli #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" 17852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/Quant.h" 182b2ebb6fSRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h" 19b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 201b00b94fSRob Suderman #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" 218190369eSFelix Schneider #include "mlir/Dialect/Utils/IndexingUtils.h" 2209f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 23f1182bd6SMogball #include "mlir/IR/DialectImplementation.h" 245a4e7760SRob Suderman #include "mlir/IR/Matchers.h" 252d0ba5e1SRob Suderman #include "mlir/IR/PatternMatch.h" 26b73e8325Snot-jenni #include "mlir/IR/TypeUtilities.h" 27057fc8e7SAmanda Tang #include "mlir/Interfaces/InferTypeOpInterface.h" 28b2812113SSuraj Sudhir #include "mlir/Transforms/InliningUtils.h" 29d89a0a65SAviad Cohen #include "llvm/ADT/APFloat.h" 30b0532286SRob Suderman #include "llvm/ADT/DenseMap.h" 31f1182bd6SMogball #include "llvm/ADT/TypeSwitch.h" 32b2812113SSuraj Sudhir 
#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

namespace {
// The generated bytecode helpers are included inside the anonymous namespace;
// the `::readAttribute` / `::writeAttribute` / `::readType` / `::writeType`
// calls below presumably resolve to them (TODO confirm against the generated
// file).
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//

/// Inliner hooks for the Tosa dialect: every operation may be inlined, while
/// region-into-region inlining is restricted to destinations owned by
/// tosa::IfOp or tosa::WhileOp.
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default. None of the arguments are
  /// inspected; the hook unconditionally returns true.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  /// A source region may be inlined into `dest` only when the operation that
  /// owns `dest` is a tosa::IfOp or a tosa::WhileOp. `src`, `wouldBeCloned`
  /// and `map` are not consulted.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};

/// This class implements the bytecode interface for the Tosa dialect. The
/// attribute and type hooks forward to free functions at global scope, and
/// versioning is not supported yet.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  /// Reads an attribute by forwarding to the global-scope `readAttribute`
  /// with this interface's context.
  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  /// Writes `attr` by forwarding to the global-scope `writeAttribute`.
  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  /// Reads a type by forwarding to the global-scope `readType` with this
  /// interface's context.
  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  /// Writes `type` by forwarding to the global-scope `writeType`.
  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  /// Intentionally writes nothing: the dialect has no version payload yet.
  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  /// Always reports an error through `reader` and returns nullptr, because
  /// the dialect does not support versioning.
  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  /// No upgrade work is performed; always succeeds.
  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body as the op's only loop region.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
131b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 132b2812113SSuraj Sudhir 133b2812113SSuraj Sudhir void TosaDialect::initialize() { 134f09db6a3SJerry-Ge addTypes< 135f09db6a3SJerry-Ge #define GET_TYPEDEF_LIST 136f09db6a3SJerry-Ge #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc" 137f09db6a3SJerry-Ge >(); 138b2812113SSuraj Sudhir addOperations< 139b2812113SSuraj Sudhir #define GET_OP_LIST 140b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" 141b2812113SSuraj Sudhir >(); 142f1182bd6SMogball addAttributes< 143f1182bd6SMogball #define GET_ATTRDEF_LIST 144f1182bd6SMogball #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" 145f1182bd6SMogball >(); 1461a7e8b90SJacques Pienaar addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>(); 147513cdb82SJustin Fargnoli declarePromisedInterfaces< 148513cdb82SJustin Fargnoli mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp, 14982383d5fSTai Ly ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp, 150513cdb82SJustin Fargnoli LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp, 151513cdb82SJustin Fargnoli LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp, 152513cdb82SJustin Fargnoli BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp, 153513cdb82SJustin Fargnoli NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp, 154513cdb82SJustin Fargnoli GreaterEqualOp, MatMulOp>(); 155b2812113SSuraj Sudhir } 156b2812113SSuraj Sudhir 157db9713cdSStella Laurenzo Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, 158db9713cdSStella Laurenzo Type type, Location loc) { 159db9713cdSStella Laurenzo // Tosa dialect constants only support ElementsAttr unlike standard dialect 160db9713cdSStella Laurenzo // constant which supports all attributes. 
161f09db6a3SJerry-Ge if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) { 162f09db6a3SJerry-Ge return builder.create<tosa::ConstShapeOp>( 163f09db6a3SJerry-Ge loc, type, llvm::cast<DenseIntElementsAttr>(value)); 164f09db6a3SJerry-Ge } 165c1fa60b4STres Popp if (llvm::isa<ElementsAttr>(value)) 166c1fa60b4STres Popp return builder.create<tosa::ConstOp>(loc, type, 167c1fa60b4STres Popp llvm::cast<ElementsAttr>(value)); 168db9713cdSStella Laurenzo return nullptr; 169db9713cdSStella Laurenzo } 170db9713cdSStella Laurenzo 171db9713cdSStella Laurenzo //===----------------------------------------------------------------------===// 172af972f01STai Ly // Parsers and printers 173af972f01STai Ly //===----------------------------------------------------------------------===// 174af972f01STai Ly 175af972f01STai Ly ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, 176af972f01STai Ly Attribute &attr) { 177af972f01STai Ly if (succeeded(parser.parseOptionalEqual())) { 178af972f01STai Ly if (failed(parser.parseAttribute(attr))) { 179af972f01STai Ly return parser.emitError(parser.getCurrentLocation()) 180af972f01STai Ly << "expected attribute"; 181af972f01STai Ly } 182a5757c5bSChristian Sigg if (auto typedAttr = dyn_cast<TypedAttr>(attr)) { 183af972f01STai Ly typeAttr = TypeAttr::get(typedAttr.getType()); 184af972f01STai Ly } 185af972f01STai Ly return success(); 186af972f01STai Ly } 187af972f01STai Ly 188af972f01STai Ly Type type; 189af972f01STai Ly if (failed(parser.parseColonType(type))) { 190af972f01STai Ly return parser.emitError(parser.getCurrentLocation()) << "expected type"; 191af972f01STai Ly } 192af972f01STai Ly typeAttr = TypeAttr::get(type); 193af972f01STai Ly 194af972f01STai Ly return success(); 195af972f01STai Ly } 196af972f01STai Ly 197af972f01STai Ly void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, 198af972f01STai Ly Attribute attr) { 199af972f01STai Ly bool needsSpace = false; 
200a5757c5bSChristian Sigg auto typedAttr = dyn_cast_or_null<TypedAttr>(attr); 201af972f01STai Ly if (!typedAttr || typedAttr.getType() != type.getValue()) { 202af972f01STai Ly p << ": "; 203af972f01STai Ly p.printAttribute(type); 204af972f01STai Ly needsSpace = true; // subsequent attr value needs a space separator 205af972f01STai Ly } 206af972f01STai Ly if (attr) { 207af972f01STai Ly if (needsSpace) 208af972f01STai Ly p << ' '; 209af972f01STai Ly p << "= "; 210af972f01STai Ly p.printAttribute(attr); 211af972f01STai Ly } 212af972f01STai Ly } 213af972f01STai Ly 214af972f01STai Ly //===----------------------------------------------------------------------===// 215b2812113SSuraj Sudhir // TOSA Operator Verifiers. 216b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 217b2812113SSuraj Sudhir 218d2353695SPeiming Liu template <typename T> 219d2353695SPeiming Liu static LogicalResult verifyConvOp(T op) { 220b2812113SSuraj Sudhir // All TOSA conv ops have an input() and weight(). 
221c1fa60b4STres Popp auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType()); 222360a03c9SJack Frankland 223360a03c9SJack Frankland RankedTensorType weightType; 224360a03c9SJack Frankland if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) 225360a03c9SJack Frankland weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType()); 226360a03c9SJack Frankland else 227360a03c9SJack Frankland weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType()); 228b2812113SSuraj Sudhir 229b2812113SSuraj Sudhir // Must be ranked tensor types 230afb05823SMehdi Amini if (!inputType) { 23113448db0SJacques Pienaar op.emitOpError("expect a ranked tensor for input, got ") << op.getInput(); 232b2812113SSuraj Sudhir return failure(); 233afb05823SMehdi Amini } 234afb05823SMehdi Amini if (!weightType) { 235360a03c9SJack Frankland if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) { 236360a03c9SJack Frankland op.emitOpError("expect a ranked tensor for filter, got ") 237360a03c9SJack Frankland << op.getFilter(); 238360a03c9SJack Frankland } else { 239360a03c9SJack Frankland op.emitOpError("expect a ranked tensor for weight, got ") 240360a03c9SJack Frankland << op.getWeight(); 241360a03c9SJack Frankland } 242afb05823SMehdi Amini return failure(); 243afb05823SMehdi Amini } 244b2812113SSuraj Sudhir 245c7676d99SRob Suderman auto inputEType = inputType.getElementType(); 246c7676d99SRob Suderman auto weightEType = weightType.getElementType(); 247c7676d99SRob Suderman 248c1fa60b4STres Popp bool inputIsQuant = !llvm::isa<FloatType>(inputEType); 249c1fa60b4STres Popp bool weightIsQuant = !llvm::isa<FloatType>(weightEType); 250b2812113SSuraj Sudhir 251b2812113SSuraj Sudhir // Either both must be quantized or both unquantized. 
252afb05823SMehdi Amini if (inputIsQuant != weightIsQuant) { 253afb05823SMehdi Amini op.emitOpError( 254afb05823SMehdi Amini "expect both input and weight to be float or not together, got ") 255afb05823SMehdi Amini << inputEType << " and " << weightEType; 256b2812113SSuraj Sudhir return failure(); 257afb05823SMehdi Amini } 258b2812113SSuraj Sudhir 259b2812113SSuraj Sudhir // Quantized type must have constructed the quantizationattr, and unquantized 260b2812113SSuraj Sudhir // types should not have a quantizationattr. 26113448db0SJacques Pienaar if ((inputIsQuant && !op.getQuantizationInfo()) || 26213448db0SJacques Pienaar (!inputIsQuant && op.getQuantizationInfo())) { 263afb05823SMehdi Amini op.emitOpError("quantizationattr is required for quantized type, and not " 264afb05823SMehdi Amini "allowed for float type"); 265b2812113SSuraj Sudhir return failure(); 266afb05823SMehdi Amini } 267a54efdbdSArteen Abrishami return success(); 268a54efdbdSArteen Abrishami } 269a54efdbdSArteen Abrishami 270a54efdbdSArteen Abrishami LogicalResult tosa::ConstOp::verify() { 271a54efdbdSArteen Abrishami 272a54efdbdSArteen Abrishami auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType()); 273a54efdbdSArteen Abrishami auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType()); 274a54efdbdSArteen Abrishami 275a54efdbdSArteen Abrishami if (!attrType || !outputType) { 276a54efdbdSArteen Abrishami emitOpError("expected tensors for attr/result type"); 277a54efdbdSArteen Abrishami return failure(); 278a54efdbdSArteen Abrishami } 279a54efdbdSArteen Abrishami 280a54efdbdSArteen Abrishami if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>( 281a54efdbdSArteen Abrishami outputType.getElementType())) { 282a54efdbdSArteen Abrishami if (result.getStorageType() == attrType.getElementType()) 283a54efdbdSArteen Abrishami return success(); 284a54efdbdSArteen Abrishami } 285a54efdbdSArteen Abrishami 286a54efdbdSArteen Abrishami if (attrType.getElementType() != 
outputType.getElementType()) { 287a54efdbdSArteen Abrishami emitOpError("expected same attr/result element types"); 288a54efdbdSArteen Abrishami return failure(); 289a54efdbdSArteen Abrishami } 290b2812113SSuraj Sudhir 291b2812113SSuraj Sudhir return success(); 292b2812113SSuraj Sudhir } 293b2812113SSuraj Sudhir 294360a03c9SJack Frankland template <typename T> 295360a03c9SJack Frankland static LogicalResult verifyConvOpModes(T op) { 296360a03c9SJack Frankland auto inputEType = 297360a03c9SJack Frankland llvm::cast<ShapedType>(op.getInput().getType()).getElementType(); 298360a03c9SJack Frankland 299360a03c9SJack Frankland if (auto quantType = 300360a03c9SJack Frankland llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) 301360a03c9SJack Frankland inputEType = quantType.getStorageType(); 302360a03c9SJack Frankland 303360a03c9SJack Frankland auto accType = op.getAccType(); 304360a03c9SJack Frankland if (inputEType.isInteger(8) && !accType.isInteger(32)) 305360a03c9SJack Frankland return op.emitOpError("accumulator type for i8 tensor is not i32"); 306360a03c9SJack Frankland 307360a03c9SJack Frankland if (inputEType.isInteger(16) && !accType.isInteger(48)) 308360a03c9SJack Frankland return op.emitOpError("accumulator type for i16 tensor is not i48"); 309360a03c9SJack Frankland 3107a77f14cSMatthias Springer if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16()) 311360a03c9SJack Frankland return op.emitOpError("accumulator type for f8 tensor is not f16"); 312360a03c9SJack Frankland 313360a03c9SJack Frankland if (inputEType.isF16() && !(accType.isF16() || accType.isF32())) 314360a03c9SJack Frankland return op.emitOpError("accumulator type for f16 tensor is not f16/f32"); 315360a03c9SJack Frankland 316360a03c9SJack Frankland if (inputEType.isBF16() && !accType.isF32()) 317360a03c9SJack Frankland return op.emitOpError("accumulator type for bf16 tensor is not f32"); 318360a03c9SJack Frankland 319360a03c9SJack Frankland if (inputEType.isF32() && 
!accType.isF32()) 320360a03c9SJack Frankland return op.emitOpError("accumulator type for f32 tensor is not f32"); 321360a03c9SJack Frankland 322360a03c9SJack Frankland return success(); 323360a03c9SJack Frankland } 324360a03c9SJack Frankland 325414709eeSGeorgios Pinitas LogicalResult tosa::ArgMaxOp::verify() { 326414709eeSGeorgios Pinitas // Ensure output is of 32-bit integer 327414709eeSGeorgios Pinitas const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType(); 328414709eeSGeorgios Pinitas if (!resultETy.isIntOrIndex()) 329414709eeSGeorgios Pinitas return emitOpError("result tensor is not of integer type"); 330414709eeSGeorgios Pinitas 331414709eeSGeorgios Pinitas // Ensure axis is within the tensor rank 332414709eeSGeorgios Pinitas const auto inputType = llvm::cast<ShapedType>(getInput().getType()); 333414709eeSGeorgios Pinitas const int64_t axis = getAxisAttr().getInt(); 334414709eeSGeorgios Pinitas if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank())) 335414709eeSGeorgios Pinitas return emitOpError("specified axis is outside the rank of the tensor"); 336414709eeSGeorgios Pinitas 337414709eeSGeorgios Pinitas return success(); 338414709eeSGeorgios Pinitas } 339414709eeSGeorgios Pinitas 3401be88f5aSRiver Riddle LogicalResult tosa::AvgPool2dOp::verify() { 341b76d8f7dSKai Sasaki auto inputType = llvm::cast<ShapedType>(getInput().getType()); 342b76d8f7dSKai Sasaki 343b76d8f7dSKai Sasaki auto inputETy = inputType.getElementType(); 344c1fa60b4STres Popp auto resultETy = llvm::cast<ShapedType>(getType()).getElementType(); 34595e4b715SRob Suderman 346c1fa60b4STres Popp if (auto quantType = 347c1fa60b4STres Popp llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) 34895e4b715SRob Suderman inputETy = quantType.getStorageType(); 34995e4b715SRob Suderman 350c1fa60b4STres Popp if (auto quantType = 351c1fa60b4STres Popp llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy)) 35295e4b715SRob Suderman resultETy = 
quantType.getStorageType(); 35395e4b715SRob Suderman 35426a7f423STatWai Chong auto accType = getAccType(); 355d46a135dSTres Popp if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32)) 35626a7f423STatWai Chong return emitOpError("accumulator type for integer tensor is not i32"); 35726a7f423STatWai Chong 3582c9ddfc7Sfabrizio-indirli if (inputETy.isF16() && !(accType.isF16() || accType.isF32())) 3592c9ddfc7Sfabrizio-indirli return emitOpError("accumulator type for f16 tensor is not f16/f32"); 3602c9ddfc7Sfabrizio-indirli 3612c9ddfc7Sfabrizio-indirli if (inputETy.isBF16() && !accType.isF32()) 3622c9ddfc7Sfabrizio-indirli return emitOpError("accumulator type for bf16 tensor is not f32"); 36326a7f423STatWai Chong 36426a7f423STatWai Chong if (inputETy.isF32() && !accType.isF32()) 36526a7f423STatWai Chong return emitOpError("accumulator type for f32 tensor is not f32"); 36626a7f423STatWai Chong 3672c9ddfc7Sfabrizio-indirli if ((inputETy.isF32() && resultETy.isF32()) || 3682c9ddfc7Sfabrizio-indirli (inputETy.isF16() && resultETy.isF16()) || 3692c9ddfc7Sfabrizio-indirli (inputETy.isBF16() && resultETy.isBF16()) || 3702c9ddfc7Sfabrizio-indirli (inputETy.isInteger(8) && resultETy.isInteger(8)) || 3712c9ddfc7Sfabrizio-indirli (inputETy.isInteger(16) && resultETy.isInteger(16))) 37295e4b715SRob Suderman return success(); 37395e4b715SRob Suderman 3741be88f5aSRiver Riddle return emitOpError("input/output element types are incompatible."); 37595e4b715SRob Suderman } 37695e4b715SRob Suderman 377dde7b80eSfabrizio-indirli LogicalResult tosa::ClampOp::verify() { 378dde7b80eSfabrizio-indirli mlir::Type inputETy = 379dde7b80eSfabrizio-indirli llvm::cast<ShapedType>(getInput().getType()).getElementType(); 380a4803d8aSAdrian Kuegel if (auto quantType = 381a4803d8aSAdrian Kuegel llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) { 382a4803d8aSAdrian Kuegel inputETy = quantType.getStorageType(); 383a4803d8aSAdrian Kuegel } 384dde7b80eSfabrizio-indirli mlir::Type maxFpType 
= getMaxFpAttr().getType(); 385dde7b80eSfabrizio-indirli mlir::Type minFpType = getMinFpAttr().getType(); 386dde7b80eSfabrizio-indirli mlir::Type outputETy = 387dde7b80eSfabrizio-indirli llvm::cast<ShapedType>(getOutput().getType()).getElementType(); 388a4803d8aSAdrian Kuegel if (auto quantType = 389a4803d8aSAdrian Kuegel llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) { 390a4803d8aSAdrian Kuegel outputETy = quantType.getStorageType(); 391a4803d8aSAdrian Kuegel } 392dde7b80eSfabrizio-indirli unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth(); 393dde7b80eSfabrizio-indirli 394dde7b80eSfabrizio-indirli if (inputETy != outputETy) 395dde7b80eSfabrizio-indirli return emitOpError("input/output element types are incompatible."); 396dde7b80eSfabrizio-indirli 397a54efdbdSArteen Abrishami // If input datatype is float, check that the two min/max_fp attributes 398a54efdbdSArteen Abrishami // share the same type and that their type is either the same of the input's 399a54efdbdSArteen Abrishami // datatype, or a float type whose bitwidth > input datatype bitwidth. 400dde7b80eSfabrizio-indirli if (!inputETy.isInteger(dataTypeBitWidth)) { 401dde7b80eSfabrizio-indirli if (((maxFpType != minFpType) || 402dde7b80eSfabrizio-indirli (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <= 403dde7b80eSfabrizio-indirli inputETy.getIntOrFloatBitWidth()))) 404dde7b80eSfabrizio-indirli return emitOpError("min/max attributes types are incompatible with " 405dde7b80eSfabrizio-indirli "input/output element types."); 406dde7b80eSfabrizio-indirli } 407dde7b80eSfabrizio-indirli 408dde7b80eSfabrizio-indirli return success(); 409dde7b80eSfabrizio-indirli } 410dde7b80eSfabrizio-indirli 411b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 412b2812113SSuraj Sudhir // TOSA Operator Quantization Builders. 
413b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 414b2812113SSuraj Sudhir 415b2812113SSuraj Sudhir /// This builder is called on all convolution operators except TransposeConv, 416b2812113SSuraj Sudhir /// which has specialized output shape semantics. The builder also defines the 417b2812113SSuraj Sudhir /// bitwidth of the output given the bit width of the input & weight content. 418ac3587f2SStella Laurenzo static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, 419b2812113SSuraj Sudhir Type outputType, Value input, Value weight, 42011030c7dSAlexander Shaposhnikov Value bias, DenseI64ArrayAttr pad, 42111030c7dSAlexander Shaposhnikov DenseI64ArrayAttr stride, 422360a03c9SJack Frankland DenseI64ArrayAttr dilation, 423360a03c9SJack Frankland TypeAttr accType) { 424b2812113SSuraj Sudhir 425b2812113SSuraj Sudhir result.addOperands({input, weight, bias}); 426b2812113SSuraj Sudhir result.addAttribute("pad", pad); 427b2812113SSuraj Sudhir result.addAttribute("stride", stride); 428b2812113SSuraj Sudhir result.addAttribute("dilation", dilation); 429360a03c9SJack Frankland result.addAttribute("acc_type", accType); 430b2812113SSuraj Sudhir 431b2812113SSuraj Sudhir auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); 432b2812113SSuraj Sudhir if (quantAttr) { 433b2812113SSuraj Sudhir result.addAttribute("quantization_info", quantAttr); 434b2812113SSuraj Sudhir result.addTypes( 435b2812113SSuraj Sudhir buildConvOpResultTypeInfo(builder, outputType, input, weight)); 436b2812113SSuraj Sudhir } else { 437b2812113SSuraj Sudhir result.addTypes(outputType); 438b2812113SSuraj Sudhir } 439b2812113SSuraj Sudhir } 440b2812113SSuraj Sudhir 441a54efdbdSArteen Abrishami /// Handles tosa.transpose_conv2d which has outpad and output shape 442a54efdbdSArteen Abrishami /// attributes. 
44311030c7dSAlexander Shaposhnikov static void buildTransConvOpWithQuantInfo( 44411030c7dSAlexander Shaposhnikov OpBuilder &builder, OperationState &result, Type outputType, Value input, 44511030c7dSAlexander Shaposhnikov Value weight, Value bias, DenseI64ArrayAttr outpad, 446360a03c9SJack Frankland DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) { 447b2812113SSuraj Sudhir result.addOperands({input, weight, bias}); 448b2812113SSuraj Sudhir result.addAttribute("out_pad", outpad); 449b2812113SSuraj Sudhir result.addAttribute("stride", stride); 450b2812113SSuraj Sudhir result.addAttribute("out_shape", outputShape); 451360a03c9SJack Frankland result.addAttribute("acc_type", accType); 452b2812113SSuraj Sudhir auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); 453b2812113SSuraj Sudhir 454b2812113SSuraj Sudhir if (quantAttr) { 455b2812113SSuraj Sudhir result.addAttribute("quantization_info", quantAttr); 456b2812113SSuraj Sudhir result.addTypes( 457b2812113SSuraj Sudhir buildConvOpResultTypeInfo(builder, outputType, input, weight)); 458b2812113SSuraj Sudhir } else { 459b2812113SSuraj Sudhir result.addTypes(outputType); 460b2812113SSuraj Sudhir } 461b2812113SSuraj Sudhir } 462b2812113SSuraj Sudhir 463b2812113SSuraj Sudhir /// The tosa.fully_connected op has its own builder as it does not have 464b2812113SSuraj Sudhir /// strides/dilation/padding. 
465ac3587f2SStella Laurenzo static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, 466b2812113SSuraj Sudhir Type outputType, Value input, Value weight, 467b2812113SSuraj Sudhir Value bias) { 468b2812113SSuraj Sudhir 469b2812113SSuraj Sudhir result.addOperands({input, weight, bias}); 470b2812113SSuraj Sudhir auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); 471b2812113SSuraj Sudhir if (quantAttr) { 472b2812113SSuraj Sudhir result.addAttribute("quantization_info", quantAttr); 473b2812113SSuraj Sudhir result.addTypes( 474b2812113SSuraj Sudhir buildConvOpResultTypeInfo(builder, outputType, input, weight)); 475b2812113SSuraj Sudhir } else { 476b2812113SSuraj Sudhir result.addTypes(outputType); 477b2812113SSuraj Sudhir } 478b2812113SSuraj Sudhir } 479b2812113SSuraj Sudhir 480a54efdbdSArteen Abrishami /// The tosa.matmul op is also intended to be generated where a 481a54efdbdSArteen Abrishami /// fully_connected op must be constructed where the weight is not a constant. 482a54efdbdSArteen Abrishami /// In this case, the fully_connected op must be expressed using matmul. 483b2812113SSuraj Sudhir /// TODO: Add link to the leglization document explaining this. 
484ac3587f2SStella Laurenzo static void buildMatMulOpWithQuantInfo(OpBuilder &builder, 485ac3587f2SStella Laurenzo OperationState &result, Type outputType, 486ac3587f2SStella Laurenzo Value a, Value b) { 487b2812113SSuraj Sudhir result.addOperands({a, b}); 488b2812113SSuraj Sudhir auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b); 489b2812113SSuraj Sudhir 490b2812113SSuraj Sudhir if (quantAttr) { 491b2812113SSuraj Sudhir result.addAttribute("quantization_info", quantAttr); 492b2812113SSuraj Sudhir 493c1fa60b4STres Popp auto inputType = llvm::dyn_cast<ShapedType>(a.getType()); 4948662a2f2SRob Suderman assert(inputType && "Input must be a shaped tensor type!"); 495b2812113SSuraj Sudhir 496c1fa60b4STres Popp auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>( 497c1fa60b4STres Popp inputType.getElementType()); 498b2812113SSuraj Sudhir assert(inputQType && "Tensor must have quantized datatype!"); 499b2812113SSuraj Sudhir 500b2812113SSuraj Sudhir unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); 501b2812113SSuraj Sudhir 502c1fa60b4STres Popp auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType); 5038662a2f2SRob Suderman assert(outputShapedType && "Output must be a shaped type"); 504b2812113SSuraj Sudhir 505b2812113SSuraj Sudhir IntegerType accElementType; 506b2812113SSuraj Sudhir if (inputBits == 16) 507b2812113SSuraj Sudhir accElementType = builder.getIntegerType(48); 508b2812113SSuraj Sudhir else 509b2812113SSuraj Sudhir accElementType = builder.getI32Type(); 5108662a2f2SRob Suderman auto accType = outputShapedType.clone(accElementType); 511b2812113SSuraj Sudhir result.addTypes(accType); 512b2812113SSuraj Sudhir } else { 513b2812113SSuraj Sudhir result.addTypes(outputType); 514b2812113SSuraj Sudhir } 515b2812113SSuraj Sudhir } 516b2812113SSuraj Sudhir 517a54efdbdSArteen Abrishami /// Both the tosa.avg_pool2d and unary ops use the same 518a54efdbdSArteen Abrishami /// UnaruOpQuantizationAttr but avg_pool operator has 
its own builder as it 519a54efdbdSArteen Abrishami /// has additional parameters not part of the unary ops. 52026a7f423STatWai Chong static void 52126a7f423STatWai Chong buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, 52226a7f423STatWai Chong Type outputType, Value input, 52326a7f423STatWai Chong DenseArrayAttr kernel, DenseArrayAttr stride, 524d2f06769SMehdi Amini DenseArrayAttr pad, TypeAttr accType) { 525b2812113SSuraj Sudhir result.addOperands(input); 526b2812113SSuraj Sudhir result.addAttribute("kernel", kernel); 527b2812113SSuraj Sudhir result.addAttribute("stride", stride); 528b2812113SSuraj Sudhir result.addAttribute("pad", pad); 529d2f06769SMehdi Amini result.addAttribute("acc_type", accType); 530b2812113SSuraj Sudhir auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); 531b2812113SSuraj Sudhir if (quantAttr) 532b2812113SSuraj Sudhir result.addAttribute("quantization_info", quantAttr); 533b2812113SSuraj Sudhir result.types.push_back(outputType); 534b2812113SSuraj Sudhir } 535b2812113SSuraj Sudhir 536b2812113SSuraj Sudhir /// This builder is called on single-parameter unary operators that have scale 537b2812113SSuraj Sudhir /// relationship between their input and output, expressed by the 538b2812113SSuraj Sudhir /// UnaryOpQuantizationAttr. 
539ac3587f2SStella Laurenzo static void buildUnaryOpWithQuantInfo(OpBuilder &builder, 540ac3587f2SStella Laurenzo OperationState &result, Type outputType, 541ac3587f2SStella Laurenzo Value input) { 542b2812113SSuraj Sudhir result.addOperands(input); 543b2812113SSuraj Sudhir auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); 544b2812113SSuraj Sudhir if (quantAttr) 545b2812113SSuraj Sudhir result.addAttribute("quantization_info", quantAttr); 546b2812113SSuraj Sudhir result.types.push_back(outputType); 547b2812113SSuraj Sudhir } 548b2812113SSuraj Sudhir 549b2812113SSuraj Sudhir /// This builder is called on TOSA pad operator that needs to create its own 550b2812113SSuraj Sudhir /// OptionalAttr quantization_attr parameter to scale the padding values 55182568021SSuraj Sudhir /// correctly. No pad_const is interpreted as zero-padding. 552ac3587f2SStella Laurenzo static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, 553ac3587f2SStella Laurenzo Type outputType, Value input, 554ac3587f2SStella Laurenzo Value paddings) { 555b2812113SSuraj Sudhir result.addOperands({input, paddings}); 556b2812113SSuraj Sudhir auto quantAttr = buildPadOpQuantizationAttr(builder, input); 557b2812113SSuraj Sudhir if (quantAttr) 558b2812113SSuraj Sudhir result.addAttribute("quantization_info", quantAttr); 559b2812113SSuraj Sudhir result.types.push_back(outputType); 560b2812113SSuraj Sudhir } 561b2812113SSuraj Sudhir 56282568021SSuraj Sudhir /// This builder is called on TOSA pad operator when an explicit pad_const 56382568021SSuraj Sudhir /// value is passed in. It also optionally constructs quantization_attr. 
56482568021SSuraj Sudhir static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder, 56582568021SSuraj Sudhir OperationState &result, 56682568021SSuraj Sudhir Type outputType, Value input, 56782568021SSuraj Sudhir Value paddings, 56802b6fb21SMehdi Amini Value padConst) { 56902b6fb21SMehdi Amini result.addOperands({input, paddings, padConst}); 57082568021SSuraj Sudhir auto quantAttr = buildPadOpQuantizationAttr(builder, input); 57182568021SSuraj Sudhir if (quantAttr) 57282568021SSuraj Sudhir result.addAttribute("quantization_info", quantAttr); 57382568021SSuraj Sudhir result.types.push_back(outputType); 57482568021SSuraj Sudhir } 57582568021SSuraj Sudhir 576b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 5778dea784bSRob Suderman // TOSA Operator Return Type Inference. 5788dea784bSRob Suderman //===----------------------------------------------------------------------===// 5798dea784bSRob Suderman 580b73e8325Snot-jenni static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, 581b73e8325Snot-jenni SmallVector<int64_t> &outShape) { 582b73e8325Snot-jenni int64_t outRank = 0; 583b73e8325Snot-jenni for (int i = 0, e = operands.size(); i != e; ++i) { 584b73e8325Snot-jenni auto shape = operands.getShape(i); 585b73e8325Snot-jenni if (!shape.hasRank()) { 586a54efdbdSArteen Abrishami // TODO(jennik): Update function to have better case handling for 587a54efdbdSArteen Abrishami // invalid operands and for ranked tensors. 
588b73e8325Snot-jenni return failure(); 589b73e8325Snot-jenni } 590b73e8325Snot-jenni outRank = std::max<int64_t>(outRank, shape.getRank()); 591b73e8325Snot-jenni } 592b73e8325Snot-jenni 593b73e8325Snot-jenni outShape.resize(outRank, 1); 594b73e8325Snot-jenni 595b73e8325Snot-jenni for (int i = 0, e = operands.size(); i != e; ++i) { 596b73e8325Snot-jenni auto shape = operands.getShape(i); 597b73e8325Snot-jenni auto rankDiff = outShape.size() - shape.getRank(); 598b73e8325Snot-jenni 599b73e8325Snot-jenni for (size_t i = 0, e = shape.getRank(); i < e; ++i) { 600b73e8325Snot-jenni auto dim1 = outShape[i + rankDiff]; 601b73e8325Snot-jenni auto dim2 = shape.getDimSize(i); 602b73e8325Snot-jenni auto resolvedDim = dim1; 603b73e8325Snot-jenni 604b73e8325Snot-jenni if (dim1 == 1) { 605b73e8325Snot-jenni resolvedDim = dim2; 606b73e8325Snot-jenni } else if (dim2 == 1) { 607b73e8325Snot-jenni resolvedDim = dim1; 608b73e8325Snot-jenni } else if (dim1 != dim2) { 609b73e8325Snot-jenni return failure(); 610b73e8325Snot-jenni } 611b73e8325Snot-jenni outShape[i + rankDiff] = resolvedDim; 612b73e8325Snot-jenni } 613b73e8325Snot-jenni } 614b73e8325Snot-jenni 615b73e8325Snot-jenni return success(); 616b73e8325Snot-jenni } 617b73e8325Snot-jenni 6185a4e7760SRob Suderman LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents( 61922426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 620057fc8e7SAmanda Tang ArgMaxOp::Adaptor adaptor, 6215a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 622057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 623057fc8e7SAmanda Tang IntegerAttr axis = adaptor.getProperties().axis; 6245a4e7760SRob Suderman int32_t axisVal = axis.getValue().getSExtValue(); 6255a4e7760SRob Suderman 62609349303SJacques Pienaar if (!inputShape.hasRank()) { 6275a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents()); 6285a4e7760SRob Suderman return success(); 
6295a4e7760SRob Suderman } 6305a4e7760SRob Suderman 6315a4e7760SRob Suderman SmallVector<int64_t> outShape; 63209349303SJacques Pienaar outShape.reserve(inputShape.getRank() - 1); 63309349303SJacques Pienaar for (int i = 0, s = inputShape.getRank(); i < s; i++) { 6345a4e7760SRob Suderman if (i == axisVal) 6355a4e7760SRob Suderman continue; 63609349303SJacques Pienaar outShape.push_back(inputShape.getDimSize(i)); 6375a4e7760SRob Suderman } 6385a4e7760SRob Suderman 6395a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); 6405a4e7760SRob Suderman return success(); 6415a4e7760SRob Suderman } 6425a4e7760SRob Suderman 64394f255c2SLuke Hutton LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents( 64494f255c2SLuke Hutton MLIRContext *context, ::std::optional<Location> location, 645057fc8e7SAmanda Tang RFFT2dOp::Adaptor adaptor, 64694f255c2SLuke Hutton SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 647057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 64894f255c2SLuke Hutton 64994f255c2SLuke Hutton if (!inputShape.hasRank()) 65094f255c2SLuke Hutton return failure(); 65194f255c2SLuke Hutton 65294f255c2SLuke Hutton llvm::SmallVector<int64_t> outputShape; 65394f255c2SLuke Hutton outputShape.resize(3, ShapedType::kDynamic); 65494f255c2SLuke Hutton outputShape[0] = inputShape.getDimSize(0); 65594f255c2SLuke Hutton outputShape[1] = inputShape.getDimSize(1); 65694f255c2SLuke Hutton int64_t inWidth = inputShape.getDimSize(2); 65794f255c2SLuke Hutton 65894f255c2SLuke Hutton // Note that we can support this calculation symbolically 65994f255c2SLuke Hutton // in the future e.g. 
[x, y, z] -> [x, y, z / 2 - 1] 66094f255c2SLuke Hutton if (inWidth != ShapedType::kDynamic) 66194f255c2SLuke Hutton outputShape[2] = inWidth / 2 + 1; 66294f255c2SLuke Hutton 66394f255c2SLuke Hutton inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 66494f255c2SLuke Hutton inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 66531293886SLuke Hutton 66631293886SLuke Hutton return success(); 66731293886SLuke Hutton } 66831293886SLuke Hutton 66931293886SLuke Hutton LogicalResult tosa::FFT2dOp::inferReturnTypeComponents( 67031293886SLuke Hutton MLIRContext *context, ::std::optional<Location> location, 671057fc8e7SAmanda Tang FFT2dOp::Adaptor adaptor, 67231293886SLuke Hutton SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 673057fc8e7SAmanda Tang inferredReturnShapes.push_back( 674057fc8e7SAmanda Tang ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType()))); 675057fc8e7SAmanda Tang inferredReturnShapes.push_back( 676057fc8e7SAmanda Tang ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType()))); 67794f255c2SLuke Hutton return success(); 67894f255c2SLuke Hutton } 67994f255c2SLuke Hutton 6805a4e7760SRob Suderman LogicalResult tosa::ConcatOp::inferReturnTypeComponents( 68122426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 682057fc8e7SAmanda Tang ConcatOp::Adaptor adaptor, 6835a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 6845a4e7760SRob Suderman // Infer all dimension sizes by reducing based on inputs. 
685057fc8e7SAmanda Tang const Properties &prop = adaptor.getProperties(); 686057fc8e7SAmanda Tang int32_t axis = prop.axis.getValue().getSExtValue(); 6875a4e7760SRob Suderman llvm::SmallVector<int64_t> outputShape; 6885a4e7760SRob Suderman bool hasRankedInput = false; 689057fc8e7SAmanda Tang for (auto operand : adaptor.getOperands()) { 690057fc8e7SAmanda Tang ShapeAdaptor operandShape(operand.getType()); 69109349303SJacques Pienaar if (!operandShape.hasRank()) 6925a4e7760SRob Suderman continue; 6935a4e7760SRob Suderman 6945a4e7760SRob Suderman // Copy the Operand's rank. 6955a4e7760SRob Suderman if (!hasRankedInput) 696399638f9SAliia Khasanova outputShape.resize(operandShape.getRank(), ShapedType::kDynamic); 6975a4e7760SRob Suderman 6985a4e7760SRob Suderman // Copy shapes until the dim is non-dynamic. 69909349303SJacques Pienaar for (int i = 0, s = operandShape.getRank(); i < s; i++) { 70009349303SJacques Pienaar if (i == axis || operandShape.isDynamicDim(i)) 7015a4e7760SRob Suderman continue; 702399638f9SAliia Khasanova if (outputShape[i] == ShapedType::kDynamic) 70309349303SJacques Pienaar outputShape[i] = operandShape.getDimSize(i); 70409349303SJacques Pienaar if (outputShape[i] != operandShape.getDimSize(i)) 705fd004a49SMaya Amrami return emitOptionalError(location, 706fd004a49SMaya Amrami "Cannot concat tensors with different sizes" 707fd004a49SMaya Amrami " on the non-axis dimension ", 708fd004a49SMaya Amrami i); 7095a4e7760SRob Suderman } 7105a4e7760SRob Suderman 7115a4e7760SRob Suderman hasRankedInput = true; 7125a4e7760SRob Suderman } 713c1fa60b4STres Popp Type inputType = 714057fc8e7SAmanda Tang llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType(); 7155a4e7760SRob Suderman if (!hasRankedInput) { 716fd004a49SMaya Amrami inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); 7175a4e7760SRob Suderman return success(); 7185a4e7760SRob Suderman } 7195a4e7760SRob Suderman 7205a4e7760SRob Suderman // Determine the dimension size 
along the concatenation axis. 721fb4cedccSAliia Khasanova int64_t concatDimSize = 0; 722057fc8e7SAmanda Tang for (auto operand : adaptor.getOperands()) { 723057fc8e7SAmanda Tang ShapeAdaptor operandShape(operand.getType()); 7245a4e7760SRob Suderman 7255a4e7760SRob Suderman // We need to know the length of the concatenation axis of all inputs to 7265a4e7760SRob Suderman // determine the dimension size of the output shape. 72709349303SJacques Pienaar if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) { 728399638f9SAliia Khasanova concatDimSize = ShapedType::kDynamic; 7295a4e7760SRob Suderman break; 7305a4e7760SRob Suderman } 7315a4e7760SRob Suderman 73209349303SJacques Pienaar concatDimSize += operandShape.getDimSize(axis); 7335a4e7760SRob Suderman } 7345a4e7760SRob Suderman 7355a4e7760SRob Suderman outputShape[axis] = concatDimSize; 7365a4e7760SRob Suderman 737fd004a49SMaya Amrami inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); 7385a4e7760SRob Suderman return success(); 7395a4e7760SRob Suderman } 7405a4e7760SRob Suderman 741b73e8325Snot-jenni LogicalResult tosa::EqualOp::inferReturnTypeComponents( 74222426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 7435e118f93SMehdi Amini ValueShapeRange operands, DictionaryAttr attributes, 7445e118f93SMehdi Amini OpaqueProperties properties, RegionRange regions, 745b73e8325Snot-jenni SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 74619109a27SSpenser Bauman auto elementType = IntegerType::get(context, /*width=*/1); 74719109a27SSpenser Bauman 748b73e8325Snot-jenni llvm::SmallVector<int64_t> outShape; 749b73e8325Snot-jenni if (resolveBroadcastShape(operands, outShape).failed()) { 75019109a27SSpenser Bauman inferredReturnShapes.push_back(ShapedTypeComponents(elementType)); 751b73e8325Snot-jenni return success(); 752b73e8325Snot-jenni } 753b73e8325Snot-jenni 75419109a27SSpenser Bauman 
inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType)); 755b73e8325Snot-jenni return success(); 756b73e8325Snot-jenni } 757b73e8325Snot-jenni 758b73e8325Snot-jenni bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 759b73e8325Snot-jenni if (l.size() != r.size() || l.size() != 1) 760b73e8325Snot-jenni return false; 761b73e8325Snot-jenni return succeeded(verifyCompatibleShape(l[0], r[0])); 762b73e8325Snot-jenni } 763b73e8325Snot-jenni 7645a4e7760SRob Suderman LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( 76522426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 766057fc8e7SAmanda Tang FullyConnectedOp::Adaptor adaptor, 7675a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 768057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 769057fc8e7SAmanda Tang ShapeAdaptor weightShape(adaptor.getWeight().getType()); 770057fc8e7SAmanda Tang ShapeAdaptor biasShape(adaptor.getBias().getType()); 7715a4e7760SRob Suderman 7725a4e7760SRob Suderman // All shapes are dynamic. 7735a4e7760SRob Suderman SmallVector<int64_t> outShape; 774399638f9SAliia Khasanova outShape.resize(2, ShapedType::kDynamic); 7755a4e7760SRob Suderman 77609349303SJacques Pienaar if (inputShape.hasRank()) { 77709349303SJacques Pienaar outShape[0] = inputShape.getDimSize(0); 7785a4e7760SRob Suderman } 7795a4e7760SRob Suderman 78009349303SJacques Pienaar if (weightShape.hasRank()) { 78109349303SJacques Pienaar outShape[1] = weightShape.getDimSize(0); 7825a4e7760SRob Suderman } 7835a4e7760SRob Suderman 78409349303SJacques Pienaar if (biasShape.hasRank()) { 78522426110SRamkumar Ramachandra outShape[1] = outShape[1] == ShapedType::kDynamic ? 
biasShape.getDimSize(0) 786143edecaSRob Suderman : outShape[1]; 7875a4e7760SRob Suderman } 7885a4e7760SRob Suderman 7895a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); 7905a4e7760SRob Suderman return success(); 7915a4e7760SRob Suderman } 7925a4e7760SRob Suderman 7931be88f5aSRiver Riddle LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } 7941be88f5aSRiver Riddle 7955a4e7760SRob Suderman LogicalResult tosa::MatMulOp::inferReturnTypeComponents( 79622426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 797057fc8e7SAmanda Tang MatMulOp::Adaptor adaptor, 7985a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 799057fc8e7SAmanda Tang ShapeAdaptor lhsShape(adaptor.getA().getType()); 800057fc8e7SAmanda Tang ShapeAdaptor rhsShape(adaptor.getB().getType()); 8015a4e7760SRob Suderman 8025a4e7760SRob Suderman // All shapes are dynamic. 8035a4e7760SRob Suderman SmallVector<int64_t> outShape; 804399638f9SAliia Khasanova outShape.resize(3, ShapedType::kDynamic); 8055a4e7760SRob Suderman 80609349303SJacques Pienaar if (lhsShape.hasRank()) { 80709349303SJacques Pienaar outShape[0] = lhsShape.getDimSize(0); 80809349303SJacques Pienaar outShape[1] = lhsShape.getDimSize(1); 8095a4e7760SRob Suderman } 8105a4e7760SRob Suderman 81109349303SJacques Pienaar if (rhsShape.hasRank()) { 81222426110SRamkumar Ramachandra outShape[0] = outShape[0] == ShapedType::kDynamic ? 
rhsShape.getDimSize(0) 813143edecaSRob Suderman : outShape[0]; 81409349303SJacques Pienaar outShape[2] = rhsShape.getDimSize(2); 8155a4e7760SRob Suderman } 8165a4e7760SRob Suderman 8175a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); 8185a4e7760SRob Suderman return success(); 8195a4e7760SRob Suderman } 8205a4e7760SRob Suderman 8215a4e7760SRob Suderman LogicalResult tosa::PadOp::inferReturnTypeComponents( 82222426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 823057fc8e7SAmanda Tang PadOp::Adaptor adaptor, 8245a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 825057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput1().getType()); 8267e622b61SJerry-Ge auto paddingRank = 8277e622b61SJerry-Ge cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank(); 8285a4e7760SRob Suderman SmallVector<int64_t> outputShape; 8295a4e7760SRob Suderman 8307e622b61SJerry-Ge // If the input rank is unknown, we can infer the output rank using the 8317e622b61SJerry-Ge // padding shape's rank divided by 2. 83209349303SJacques Pienaar if (!inputShape.hasRank()) { 8337e622b61SJerry-Ge outputShape.resize(paddingRank / 2, ShapedType::kDynamic); 8345a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 8355a4e7760SRob Suderman return success(); 8365a4e7760SRob Suderman } 8375a4e7760SRob Suderman 8385a4e7760SRob Suderman SmallVector<int64_t> paddingValues; 8397e622b61SJerry-Ge // If the paddings value is not a constant, all dimensions must be dynamic. 
8407e622b61SJerry-Ge if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(), 8417e622b61SJerry-Ge paddingValues)) { 8427e622b61SJerry-Ge outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); 8437e622b61SJerry-Ge inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 8447e622b61SJerry-Ge return success(); 8455a4e7760SRob Suderman } 8465a4e7760SRob Suderman 84709349303SJacques Pienaar outputShape.reserve(inputShape.getRank()); 84809349303SJacques Pienaar for (int i = 0, s = inputShape.getRank(); i < s; i++) { 84909349303SJacques Pienaar if (inputShape.isDynamicDim(i)) { 850399638f9SAliia Khasanova outputShape.push_back(ShapedType::kDynamic); 8515a4e7760SRob Suderman continue; 8525a4e7760SRob Suderman } 8537e622b61SJerry-Ge auto padFront = paddingValues[i * 2]; 8547e622b61SJerry-Ge auto padBack = paddingValues[i * 2 + 1]; 8557e622b61SJerry-Ge if (padFront < 0 || padBack < 0) { 8567e622b61SJerry-Ge // if either padding for dim i is -1, output dim is unknown 8577e622b61SJerry-Ge outputShape.push_back(ShapedType::kDynamic); 8587e622b61SJerry-Ge continue; 8597e622b61SJerry-Ge } 8605a4e7760SRob Suderman 8617e622b61SJerry-Ge outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack); 8625a4e7760SRob Suderman } 8635a4e7760SRob Suderman 8645a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 8655a4e7760SRob Suderman return success(); 8665a4e7760SRob Suderman } 8675a4e7760SRob Suderman 86837263b6cSLongsheng Mou LogicalResult tosa::PadOp::verify() { 86937263b6cSLongsheng Mou RankedTensorType inputType = getInput1().getType(); 87037263b6cSLongsheng Mou RankedTensorType outputType = getOutput().getType(); 8717e622b61SJerry-Ge auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank(); 87237263b6cSLongsheng Mou 87337263b6cSLongsheng Mou if (inputType.getRank() != outputType.getRank()) 87437263b6cSLongsheng Mou return emitOpError() << "expect same input and output tensor rank."; 
87537263b6cSLongsheng Mou 8767e622b61SJerry-Ge if (paddingRank != inputType.getRank() * 2) 877c1d01b2fSLongsheng Mou return emitOpError() << "expected padding tensor dim 0 to have size " 878c1d01b2fSLongsheng Mou << inputType.getRank() * 2 8797e622b61SJerry-Ge << " (2*rank(shape1)) but got size " << paddingRank; 88037263b6cSLongsheng Mou 88137263b6cSLongsheng Mou return success(); 88237263b6cSLongsheng Mou } 88337263b6cSLongsheng Mou 884fa5a607dSRob Suderman static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) { 885fa5a607dSRob Suderman return to_vector(llvm::map_range(shape, [](int64_t dim) { 886399638f9SAliia Khasanova return dim == -1 ? ShapedType::kDynamic : dim; 887fa5a607dSRob Suderman })); 888fa5a607dSRob Suderman } 889fa5a607dSRob Suderman 8905a4e7760SRob Suderman LogicalResult tosa::SliceOp::inferReturnTypeComponents( 89122426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 892057fc8e7SAmanda Tang SliceOp::Adaptor adaptor, 8935a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 894*956c0707SJerry-Ge 895*956c0707SJerry-Ge Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); 896*956c0707SJerry-Ge SmallVector<int64_t> start; 897*956c0707SJerry-Ge SmallVector<int64_t> size; 898*956c0707SJerry-Ge 899*956c0707SJerry-Ge if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) || 900*956c0707SJerry-Ge !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) { 901*956c0707SJerry-Ge auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank(); 902*956c0707SJerry-Ge SmallVector<int64_t> fallback(rank, ShapedType::kDynamic); 903*956c0707SJerry-Ge inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType)); 904*956c0707SJerry-Ge return success(); 905*956c0707SJerry-Ge } 9067986e0caSTai Ly 9077986e0caSTai Ly // if size[i] is -1, all remaining elements in dimension i are included 9087986e0caSTai Ly // in the slice, 
similar to TF. 9097986e0caSTai Ly ShapeAdaptor inputShape(adaptor.getInput1().getType()); 9107986e0caSTai Ly // initialize outputShape to all unknown 9117986e0caSTai Ly SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic); 9127986e0caSTai Ly if (inputShape.hasRank()) { 9137986e0caSTai Ly for (size_t i = 0; i < size.size(); i++) { 9147986e0caSTai Ly if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 && 9157986e0caSTai Ly (ShapedType::isDynamic(inputShape.getDimSize(i)) || 9167986e0caSTai Ly start[i] < inputShape.getDimSize(i))) { 9177986e0caSTai Ly // size[i] is not 0 and not < -1, and start[i] is in valid range 9187986e0caSTai Ly if (ShapedType::isDynamic(inputShape.getDimSize(i))) { 9197986e0caSTai Ly // input shape has unknown dim[i] - only valid if size[i] > 0 9207986e0caSTai Ly if (size[i] > 0) { 9217986e0caSTai Ly outputShape[i] = size[i]; 9227986e0caSTai Ly } 9237986e0caSTai Ly } else { 9247986e0caSTai Ly // input shape has known dim[i] 9257986e0caSTai Ly if (size[i] == -1) { 9267986e0caSTai Ly outputShape[i] = inputShape.getDimSize(i) - start[i]; 9277986e0caSTai Ly } else if (start[i] + size[i] <= inputShape.getDimSize(i)) { 9287986e0caSTai Ly // start[i] + size[i] is within bound of input shape's dim[i] 9297986e0caSTai Ly outputShape[i] = size[i]; 9307986e0caSTai Ly } 9317986e0caSTai Ly } 9327986e0caSTai Ly } 9337986e0caSTai Ly } 9347986e0caSTai Ly } else { 9357986e0caSTai Ly outputShape = convertToMlirShape(size); 9367986e0caSTai Ly } 9377986e0caSTai Ly inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 9385a4e7760SRob Suderman return success(); 9395a4e7760SRob Suderman } 9405a4e7760SRob Suderman 941ab7e8b76Slong.chen LogicalResult tosa::SliceOp::verify() { 942c6876b4eSJerry-Ge auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType()); 943ab7e8b76Slong.chen if (!inputType) 944ab7e8b76Slong.chen return success(); 945ab7e8b76Slong.chen 946*956c0707SJerry-Ge auto startShapeRank = 947*956c0707SJerry-Ge 
llvm::cast<tosa::shapeType>(getStart().getType()).getRank(); 948*956c0707SJerry-Ge if (inputType.getRank() != startShapeRank) 949ab7e8b76Slong.chen return emitOpError( 950ab7e8b76Slong.chen "length of start attribute is not equal rank of input shape"); 951ab7e8b76Slong.chen 952*956c0707SJerry-Ge auto sizeShapeRank = 953*956c0707SJerry-Ge llvm::cast<tosa::shapeType>(getSize().getType()).getRank(); 954*956c0707SJerry-Ge if (inputType.getRank() != sizeShapeRank) 955ab7e8b76Slong.chen return emitOpError( 956ab7e8b76Slong.chen "length of size attribute is not equal rank of input shape"); 957ab7e8b76Slong.chen 958ab7e8b76Slong.chen return success(); 959ab7e8b76Slong.chen } 960ab7e8b76Slong.chen 961519eef3bSLongsheng Mou LogicalResult tosa::MulOp::verify() { 962a58e774fSJack Frankland auto resElemType = getElementTypeOrSelf(getOutput()); 963a58e774fSJack Frankland 964a58e774fSJack Frankland // Verify if the element type among operands and result match tosa 965a58e774fSJack Frankland // specification. 966a58e774fSJack Frankland if (auto resIntType = dyn_cast<IntegerType>(resElemType)) { 967a58e774fSJack Frankland IntegerType lhsIntType = 968a58e774fSJack Frankland cast<IntegerType>(getElementTypeOrSelf(getInput1())); 969a58e774fSJack Frankland IntegerType rhsIntType = 970a58e774fSJack Frankland cast<IntegerType>(getElementTypeOrSelf(getInput2())); 971a58e774fSJack Frankland if (lhsIntType != rhsIntType) 972a58e774fSJack Frankland return emitOpError("requires the same element type for all operands"); 973a58e774fSJack Frankland 974a58e774fSJack Frankland // Though the spec requires the element type of result to be i32, a more 975a58e774fSJack Frankland // relaxed way is provided at dialect level for easier cooperating with 976a58e774fSJack Frankland // other dialects. 
977a58e774fSJack Frankland if (lhsIntType.getWidth() > resIntType.getWidth()) 978a58e774fSJack Frankland return emitOpError("invalid data type size for operands or result"); 979a58e774fSJack Frankland 980a58e774fSJack Frankland } else { 981a58e774fSJack Frankland // For other supported type, the spec requires requires the same element 982a58e774fSJack Frankland // type for all operands (excludes `shift` operand) and results. 983a58e774fSJack Frankland for (int i = 0; i < 2; ++i) { 984a58e774fSJack Frankland if (getElementTypeOrSelf(getOperand(i)) != resElemType) 985a58e774fSJack Frankland return emitOpError( 986a58e774fSJack Frankland "requires the same element type for all operands and results"); 987a58e774fSJack Frankland } 988a58e774fSJack Frankland } 989a58e774fSJack Frankland 990a58e774fSJack Frankland // Verify the op has same ranks for all main operands (excludes extra operands 991a58e774fSJack Frankland // such as shift of mul op, so this is the only difference with the built-in 992a58e774fSJack Frankland // `SameOperandsAndResultRank` trait) and results types, if known. 
993a58e774fSJack Frankland 994a58e774fSJack Frankland // delegate function that returns true if type is a shaped type with known 995a58e774fSJack Frankland // rank 996a58e774fSJack Frankland auto hasRank = [](const Type type) { 997a58e774fSJack Frankland if (auto shaped_type = dyn_cast<ShapedType>(type)) 998a58e774fSJack Frankland return shaped_type.hasRank(); 999a58e774fSJack Frankland 1000a58e774fSJack Frankland return false; 1001a58e774fSJack Frankland }; 1002a58e774fSJack Frankland 1003a58e774fSJack Frankland auto rankedOperandTypes = 1004a58e774fSJack Frankland llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank)); 1005a58e774fSJack Frankland 1006a58e774fSJack Frankland auto rankedResultTypes = 1007a58e774fSJack Frankland llvm::make_filter_range(getOperation()->getResultTypes(), hasRank); 1008a58e774fSJack Frankland 1009a58e774fSJack Frankland // If all operands and results are unranked, then no further verification. 1010a58e774fSJack Frankland if (rankedOperandTypes.empty() && rankedResultTypes.empty()) 1011a58e774fSJack Frankland return success(); 1012a58e774fSJack Frankland 1013a58e774fSJack Frankland // delegate function that returns rank of shaped type with known rank 1014a58e774fSJack Frankland auto getRank = [](const Type type) { 1015a58e774fSJack Frankland return cast<ShapedType>(type).getRank(); 1016a58e774fSJack Frankland }; 1017a58e774fSJack Frankland 1018a58e774fSJack Frankland auto rank = !rankedOperandTypes.empty() ? 
getRank(*rankedOperandTypes.begin()) 1019a58e774fSJack Frankland : getRank(*rankedResultTypes.begin()); 1020a58e774fSJack Frankland 1021a58e774fSJack Frankland for (size_t i = 0; i < 2; ++i) { 1022a58e774fSJack Frankland if (rank != getRank(rankedOperandTypes[i])) { 1023a58e774fSJack Frankland return emitOpError("operands don't have matching ranks"); 1024a58e774fSJack Frankland } 1025a58e774fSJack Frankland } 1026a58e774fSJack Frankland 1027a58e774fSJack Frankland for (const auto type : rankedResultTypes) { 1028a58e774fSJack Frankland if (rank != getRank(type)) { 1029a58e774fSJack Frankland return emitOpError("result type has different rank than operands"); 1030a58e774fSJack Frankland } 1031a58e774fSJack Frankland } 1032519eef3bSLongsheng Mou 1033519eef3bSLongsheng Mou return success(); 1034519eef3bSLongsheng Mou } 1035519eef3bSLongsheng Mou 10365a4e7760SRob Suderman LogicalResult tosa::TableOp::inferReturnTypeComponents( 103722426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1038057fc8e7SAmanda Tang TableOp::Adaptor adaptor, 10395a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1040c6876b4eSJerry-Ge ShapeAdaptor inputShape(adaptor.getInput1().getType()); 10415a4e7760SRob Suderman 104209349303SJacques Pienaar if (!inputShape.hasRank()) { 10435a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents()); 10445a4e7760SRob Suderman return success(); 10455a4e7760SRob Suderman } 10465a4e7760SRob Suderman 104709349303SJacques Pienaar inferredReturnShapes.resize(1); 104809349303SJacques Pienaar inputShape.getDims(inferredReturnShapes[0]); 10495a4e7760SRob Suderman return success(); 10505a4e7760SRob Suderman } 10515a4e7760SRob Suderman 10521e347062SLongsheng Mou LogicalResult tosa::TableOp::verify() { 1053c6876b4eSJerry-Ge TensorType inputType = getInput1().getType(); 10541e347062SLongsheng Mou TensorType outputType = getOutput().getType(); 10551e347062SLongsheng Mou 
10561e347062SLongsheng Mou if (inputType.hasRank() && outputType.hasRank() && 10571e347062SLongsheng Mou inputType.getRank() != outputType.getRank()) 10581e347062SLongsheng Mou return emitOpError() 10591e347062SLongsheng Mou << "expected input tensor rank to equal result tensor rank"; 10601e347062SLongsheng Mou 10611e347062SLongsheng Mou auto inputDims = inputType.getShape(); 10621e347062SLongsheng Mou auto outputDims = outputType.getShape(); 10631e347062SLongsheng Mou for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) { 10641e347062SLongsheng Mou int64_t dim = it.index(); 10651e347062SLongsheng Mou auto [inputDim, outputDim] = it.value(); 10661e347062SLongsheng Mou if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) { 10671e347062SLongsheng Mou return emitOpError() << "dim(result, " << dim << ") = " << outputDim 10681e347062SLongsheng Mou << " doesn't match dim(input, " << dim 10691e347062SLongsheng Mou << ") = " << inputDim; 10701e347062SLongsheng Mou } 10711e347062SLongsheng Mou } 10721e347062SLongsheng Mou return success(); 10731e347062SLongsheng Mou } 10741e347062SLongsheng Mou 1075f09db6a3SJerry-Ge LogicalResult 1076f09db6a3SJerry-Ge tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) { 1077f09db6a3SJerry-Ge // Multiples must be constants. 
1078f09db6a3SJerry-Ge DenseIntElementsAttr multiplesAttr; 1079f09db6a3SJerry-Ge if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr))) 1080f09db6a3SJerry-Ge return failure(); 1081f09db6a3SJerry-Ge multiples = llvm::to_vector( 1082f09db6a3SJerry-Ge llvm::map_range(multiplesAttr.getValues<APInt>(), 1083f09db6a3SJerry-Ge [](const APInt &val) { return val.getSExtValue(); })); 1084f09db6a3SJerry-Ge return success(); 1085f09db6a3SJerry-Ge } 1086f09db6a3SJerry-Ge 10875a4e7760SRob Suderman LogicalResult tosa::TileOp::inferReturnTypeComponents( 108822426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1089057fc8e7SAmanda Tang TileOp::Adaptor adaptor, 10905a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1091f09db6a3SJerry-Ge DenseIntElementsAttr multiplesAttr; 1092f09db6a3SJerry-Ge if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr))) 1093f09db6a3SJerry-Ge return failure(); 1094f09db6a3SJerry-Ge 1095f09db6a3SJerry-Ge SmallVector<int64_t> multiples = llvm::to_vector( 1096f09db6a3SJerry-Ge llvm::map_range(multiplesAttr.getValues<APInt>(), 1097f09db6a3SJerry-Ge [](const APInt &val) { return val.getSExtValue(); })); 1098f09db6a3SJerry-Ge 1099057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput1().getType()); 11005a4e7760SRob Suderman SmallVector<int64_t> outputShape; 110109349303SJacques Pienaar if (!inputShape.hasRank()) { 1102399638f9SAliia Khasanova outputShape.resize(multiples.size(), ShapedType::kDynamic); 11035a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 11045a4e7760SRob Suderman return success(); 11050e6c679cSKazu Hirata } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size()) 1106b6d67af2SFelix Schneider return failure(); 11075a4e7760SRob Suderman 11085a4e7760SRob Suderman // Any non dynamic dimension can be multiplied to a known size. 
11095a4e7760SRob Suderman outputShape.reserve(multiples.size()); 111009349303SJacques Pienaar for (int i = 0, s = inputShape.getRank(); i < s; i++) { 1111fb4cedccSAliia Khasanova int64_t dim = inputShape.getDimSize(i); 1112399638f9SAliia Khasanova if (dim != ShapedType::kDynamic) 11139e1a3441SAlexander Shaposhnikov dim *= multiples[i]; 11145a4e7760SRob Suderman outputShape.push_back(dim); 11155a4e7760SRob Suderman } 11165a4e7760SRob Suderman 11175a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 11185a4e7760SRob Suderman return success(); 11195a4e7760SRob Suderman } 11205a4e7760SRob Suderman 1121b6d67af2SFelix Schneider LogicalResult tosa::TileOp::verify() { 1122b6d67af2SFelix Schneider ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType()); 1123b6d67af2SFelix Schneider ShapedType outputType = llvm::cast<ShapedType>(getType()); 1124f09db6a3SJerry-Ge 1125f09db6a3SJerry-Ge shapeType multiplesType = 1126f09db6a3SJerry-Ge llvm::cast<tosa::shapeType>(getMultiples().getType()); 1127f09db6a3SJerry-Ge 1128f09db6a3SJerry-Ge auto multiplesRank = multiplesType.getRank(); 1129b6d67af2SFelix Schneider 1130b6d67af2SFelix Schneider if (inputType.hasRank()) { 1131f09db6a3SJerry-Ge if (inputType.getRank() != multiplesRank) 1132f09db6a3SJerry-Ge return emitOpError("expect 'multiples' to have rank ") 1133f09db6a3SJerry-Ge << inputType.getRank() << " but got " << multiplesRank << "."; 1134b6d67af2SFelix Schneider if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) 1135b6d67af2SFelix Schneider return emitOpError("expect same input and output tensor rank."); 1136f09db6a3SJerry-Ge } else if (outputType.hasRank() && outputType.getRank() != multiplesRank) 1137b6d67af2SFelix Schneider return emitOpError("expect 'multiples' array to have length ") 1138f09db6a3SJerry-Ge << outputType.getRank() << " but got " << multiplesRank << "."; 1139b6d67af2SFelix Schneider 1140f09db6a3SJerry-Ge SmallVector<int64_t> multiples; 
1141f09db6a3SJerry-Ge if (getConstantMultiples(multiples).succeeded() && 1142f09db6a3SJerry-Ge llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; })) 1143c8568f09SLongsheng Mou return emitOpError( 1144c8568f09SLongsheng Mou "expect element of 'multiples' to be positive integer or -1."); 1145c8568f09SLongsheng Mou 1146b6d67af2SFelix Schneider return success(); 1147b6d67af2SFelix Schneider } 1148b6d67af2SFelix Schneider 11492dd396c1SAviad Cohen bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 11502dd396c1SAviad Cohen if (l.size() != r.size() || l.size() != 1) 11512dd396c1SAviad Cohen return false; 11522dd396c1SAviad Cohen return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]); 11532dd396c1SAviad Cohen } 11542dd396c1SAviad Cohen 11558dea784bSRob Suderman LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( 115622426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1157057fc8e7SAmanda Tang ReshapeOp::Adaptor adaptor, 11588dea784bSRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1159057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput1().getType()); 1160057fc8e7SAmanda Tang Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); 11619e1a3441SAlexander Shaposhnikov llvm::SmallVector<int64_t> newShapeValue = 11629e1a3441SAlexander Shaposhnikov convertToMlirShape(adaptor.getNewShape()); 11638dea784bSRob Suderman 11648dea784bSRob Suderman // We cannot infer from the total number of elements so we must take the 11658dea784bSRob Suderman // shape attribute as exact. 
116609349303SJacques Pienaar if (!inputShape.hasRank() || !inputShape.hasStaticShape()) { 11672dd396c1SAviad Cohen inferredReturnShapes.push_back( 11682dd396c1SAviad Cohen ShapedTypeComponents(newShapeValue, inputType)); 11698dea784bSRob Suderman return success(); 11708dea784bSRob Suderman } 11718dea784bSRob Suderman 11728dea784bSRob Suderman // Determine the number of elements covered by the slice of all static 11738dea784bSRob Suderman // dimensions. This allows us to infer the length of the remaining dynamic 11748dea784bSRob Suderman // dimension. 117509349303SJacques Pienaar int64_t numElements = inputShape.getNumElements(); 11768dea784bSRob Suderman int64_t staticMul = 1; 11778dea784bSRob Suderman for (auto val : newShapeValue) { 1178fb4cedccSAliia Khasanova if (!ShapedType::isDynamic(val)) { 11798dea784bSRob Suderman staticMul *= val; 11808dea784bSRob Suderman } 11818dea784bSRob Suderman } 11828dea784bSRob Suderman 11838dea784bSRob Suderman // Determine the length of the dynamic dimension. 
11848dea784bSRob Suderman for (auto &val : newShapeValue) { 1185fb4cedccSAliia Khasanova if (ShapedType::isDynamic(val)) 11868dea784bSRob Suderman val = numElements / staticMul; 11878dea784bSRob Suderman } 11888dea784bSRob Suderman 11892dd396c1SAviad Cohen inferredReturnShapes.push_back( 11902dd396c1SAviad Cohen ShapedTypeComponents(newShapeValue, inputType)); 11918dea784bSRob Suderman return success(); 11928dea784bSRob Suderman } 11938dea784bSRob Suderman 1194db791b27SRamkumar Ramachandra llvm::LogicalResult tosa::ReshapeOp::verify() { 1195fbcd0c65SRafael Ubal TensorType inputType = getInput1().getType(); 1196fbcd0c65SRafael Ubal RankedTensorType outputType = getType(); 1197a315534eSa.puschin 1198fbcd0c65SRafael Ubal if ((int64_t)getNewShape().size() != outputType.getRank()) 1199fbcd0c65SRafael Ubal return emitOpError() << "new shape does not match result rank"; 1200fbcd0c65SRafael Ubal 1201fbcd0c65SRafael Ubal for (auto [newShapeDim, outputShapeDim] : 1202019fbcc4SLongsheng Mou zip(getNewShape(), outputType.getShape())) { 1203fbcd0c65SRafael Ubal if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic && 1204fbcd0c65SRafael Ubal newShapeDim != outputShapeDim) 1205fbcd0c65SRafael Ubal return emitOpError() << "new shape is inconsistent with result shape"; 1206fbcd0c65SRafael Ubal 1207019fbcc4SLongsheng Mou if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1) 1208019fbcc4SLongsheng Mou return emitOpError() << "new shape has invalid tensor dimension size " 1209019fbcc4SLongsheng Mou << newShapeDim; 1210019fbcc4SLongsheng Mou } 1211019fbcc4SLongsheng Mou 1212e6eb94d3SLongsheng Mou if (inputType.hasStaticShape()) { 1213a315534eSa.puschin int64_t inputElementsNum = inputType.getNumElements(); 1214e6eb94d3SLongsheng Mou if (outputType.hasStaticShape()) { 1215a315534eSa.puschin int64_t outputElementsNum = outputType.getNumElements(); 1216a315534eSa.puschin if (inputElementsNum != outputElementsNum) { 1217fbcd0c65SRafael Ubal return emitOpError() << "cannot 
reshape " << inputElementsNum 1218a315534eSa.puschin << " elements into " << outputElementsNum; 1219a315534eSa.puschin } 1220a315534eSa.puschin } 122126d896f3SRafael Ubal 1222e6eb94d3SLongsheng Mou int64_t newShapeElementsNum = std::accumulate( 1223e6eb94d3SLongsheng Mou getNewShape().begin(), getNewShape().end(), 1LL, 1224e6eb94d3SLongsheng Mou [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; }); 1225e6eb94d3SLongsheng Mou bool isStaticNewShape = 1226e6eb94d3SLongsheng Mou llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; }); 1227e6eb94d3SLongsheng Mou if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) || 1228e6eb94d3SLongsheng Mou (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) { 1229e6eb94d3SLongsheng Mou return emitOpError() << "cannot reshape " << inputElementsNum 1230e6eb94d3SLongsheng Mou << " elements into " << newShapeElementsNum; 1231e6eb94d3SLongsheng Mou } 1232e6eb94d3SLongsheng Mou } 1233e6eb94d3SLongsheng Mou 123426d896f3SRafael Ubal int missingDims = llvm::count(getNewShape(), -1); 123526d896f3SRafael Ubal if (missingDims > 1) 1236fbcd0c65SRafael Ubal return emitOpError() << "expected at most one target dimension to be -1"; 123726d896f3SRafael Ubal 1238a315534eSa.puschin return mlir::success(); 1239a315534eSa.puschin } 1240a315534eSa.puschin 1241a54efdbdSArteen Abrishami LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) { 1242ded988edSAviad Cohen // Perms must be constants. 
1243ded988edSAviad Cohen DenseIntElementsAttr permsAttr; 1244ded988edSAviad Cohen if (!matchPattern(getPerms(), m_Constant(&permsAttr))) 1245ded988edSAviad Cohen return failure(); 1246ded988edSAviad Cohen 1247a54efdbdSArteen Abrishami perms.clear(); 1248a54efdbdSArteen Abrishami for (auto v : permsAttr.getValues<APInt>()) 1249a54efdbdSArteen Abrishami perms.push_back(v.getSExtValue()); 1250ded988edSAviad Cohen 1251ded988edSAviad Cohen return success(); 1252ded988edSAviad Cohen } 1253ded988edSAviad Cohen 12545a4e7760SRob Suderman LogicalResult tosa::TransposeOp::inferReturnTypeComponents( 125522426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1256057fc8e7SAmanda Tang TransposeOp::Adaptor adaptor, 12575a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1258057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput1().getType()); 1259057fc8e7SAmanda Tang ShapeAdaptor permsShape(adaptor.getPerms().getType()); 12605a4e7760SRob Suderman 12617563eb64SFelix Schneider // We cannot infer anything from a rank-0 "permutation" tensor. 12627563eb64SFelix Schneider if (permsShape.hasRank() && permsShape.getRank() == 0) 12637563eb64SFelix Schneider return failure(); 12647563eb64SFelix Schneider 12655a4e7760SRob Suderman // If input rank and permutation length is unknown, the output rank is 12665a4e7760SRob Suderman // unknown. 1267826d3eaaSRob Suderman if (!inputShape.hasRank() || !permsShape.hasRank() || 1268826d3eaaSRob Suderman permsShape.isDynamicDim(0)) { 12695a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents()); 12705a4e7760SRob Suderman return success(); 12715a4e7760SRob Suderman } 12725a4e7760SRob Suderman 1273a54efdbdSArteen Abrishami // This would imply the number of permutations does not match the rank of 1274a54efdbdSArteen Abrishami // the input which is illegal. 
1275826d3eaaSRob Suderman if (permsShape.getDimSize(0) != inputShape.getRank()) { 1276826d3eaaSRob Suderman return failure(); 1277826d3eaaSRob Suderman } 1278826d3eaaSRob Suderman 12795a4e7760SRob Suderman SmallVector<int64_t> outputShape; 12805a4e7760SRob Suderman // Rank-0 means no permutations matter. 128109349303SJacques Pienaar if (inputShape.getRank() == 0) { 12825a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 12835a4e7760SRob Suderman return success(); 12845a4e7760SRob Suderman } 12855a4e7760SRob Suderman 12865a4e7760SRob Suderman // Check whether the input dimensions are all the same. 12875a4e7760SRob Suderman bool allTheSame = true; 128809349303SJacques Pienaar for (int i = 1, s = inputShape.getRank(); i < s; i++) { 128909349303SJacques Pienaar if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) { 12905a4e7760SRob Suderman allTheSame = false; 12915a4e7760SRob Suderman break; 12925a4e7760SRob Suderman } 12935a4e7760SRob Suderman } 12945a4e7760SRob Suderman 12955a4e7760SRob Suderman // If all of the input dimensions are the same we don't care about the 12965a4e7760SRob Suderman // permutation. 12975a4e7760SRob Suderman if (allTheSame) { 129809349303SJacques Pienaar outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0)); 12995a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 13005a4e7760SRob Suderman return success(); 13015a4e7760SRob Suderman } 13025a4e7760SRob Suderman 1303399638f9SAliia Khasanova outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); 13045a4e7760SRob Suderman // If the permuations are a constant we can directly determine the output 13055a4e7760SRob Suderman // shape. 
1306057fc8e7SAmanda Tang DenseIntElementsAttr attr; 1307057fc8e7SAmanda Tang if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) && 1308057fc8e7SAmanda Tang attr.getType().getRank() == 1) { 1309057fc8e7SAmanda Tang ShapeAdaptor permShape = attr; 131047793481SKai Sasaki // Constant permutation must be the same length as the input rank. 131147793481SKai Sasaki if (inputShape.getRank() != permShape.getRank()) 131247793481SKai Sasaki return emitOptionalError(location, 131347793481SKai Sasaki "constant permutation must be the same length" 131447793481SKai Sasaki " as the input rank"); 131547793481SKai Sasaki 131647793481SKai Sasaki // Constant permutation values must be within the input rank. 131747793481SKai Sasaki for (int i = 0, e = inputShape.getRank(); i < e; i++) { 131847793481SKai Sasaki if (inputShape.getRank() <= permShape.getDimSize(i)) 131947793481SKai Sasaki return failure(); 132047793481SKai Sasaki } 132147793481SKai Sasaki 132209349303SJacques Pienaar outputShape.reserve(inputShape.getRank()); 132309349303SJacques Pienaar for (int i = 0, s = inputShape.getRank(); i < s; i++) { 132409349303SJacques Pienaar outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i)); 13255a4e7760SRob Suderman } 13265a4e7760SRob Suderman } 13275a4e7760SRob Suderman 13285a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 13295a4e7760SRob Suderman return success(); 13305a4e7760SRob Suderman } 13315a4e7760SRob Suderman 13328190369eSFelix Schneider LogicalResult tosa::TransposeOp::verify() { 13338190369eSFelix Schneider TensorType inputType = getInput1().getType(); 13348190369eSFelix Schneider TensorType permType = getPerms().getType(); 13358190369eSFelix Schneider TensorType outputType = getOutput().getType(); 13368190369eSFelix Schneider 13378190369eSFelix Schneider if (permType.hasRank() && permType.getRank() != 1) 13388190369eSFelix Schneider return emitOpError() 13398190369eSFelix Schneider << "expected permutation tensor to be 
rank 1 but got rank " 13408190369eSFelix Schneider << permType.getRank(); 13418190369eSFelix Schneider if (inputType.hasRank() && permType.hasRank()) 13428190369eSFelix Schneider if (!permType.isDynamicDim(0) && 13438190369eSFelix Schneider permType.getDimSize(0) != inputType.getRank()) 13448190369eSFelix Schneider return emitOpError() << "expected permutation tensor dim 0 to have size " 13458190369eSFelix Schneider << inputType.getRank() 13468190369eSFelix Schneider << " (input rank) but got size " 13478190369eSFelix Schneider << permType.getDimSize(0); 13488190369eSFelix Schneider if (inputType.hasRank() && outputType.hasRank() && 13498190369eSFelix Schneider inputType.getRank() != outputType.getRank()) 13508190369eSFelix Schneider return emitOpError() 13518190369eSFelix Schneider << "expected input tensor rank to equal result tensor rank"; 13528190369eSFelix Schneider if (outputType.hasRank() && permType.hasRank()) 13538190369eSFelix Schneider if (!permType.isDynamicDim(0) && 13548190369eSFelix Schneider permType.getDimSize(0) != outputType.getRank()) 13558190369eSFelix Schneider return emitOpError() << "expected permutation tensor dim 0 to have size " 13568190369eSFelix Schneider << outputType.getRank() 13578190369eSFelix Schneider << " (output rank) but got size " 13588190369eSFelix Schneider << permType.getDimSize(0); 13598190369eSFelix Schneider 1360a54efdbdSArteen Abrishami SmallVector<int32_t> constantPerms; 13618190369eSFelix Schneider if (succeeded(getConstantPerms(constantPerms))) { 1362a54efdbdSArteen Abrishami // Assert that the permutation tensor has a rank, which means that the 1363a54efdbdSArteen Abrishami // rank has been verified above. 
13648190369eSFelix Schneider assert(permType.hasRank() && 13658190369eSFelix Schneider "Unexpectedly found permutation tensor without rank"); 1366a54efdbdSArteen Abrishami if (!llvm::all_of(constantPerms, 1367a54efdbdSArteen Abrishami [&constantPerms](int32_t s) { 1368a54efdbdSArteen Abrishami return s >= 0 && 1369a54efdbdSArteen Abrishami static_cast<size_t>(s) < constantPerms.size(); 1370a54efdbdSArteen Abrishami }) || 1371a54efdbdSArteen Abrishami !isPermutationVector(llvm::to_vector(llvm::map_range( 1372a54efdbdSArteen Abrishami constantPerms, [](int32_t v) -> int64_t { return v; })))) 13738190369eSFelix Schneider return emitOpError() << "expected valid permutation tensor"; 1374c8b5d30fSDarshanRamakant 1375a54efdbdSArteen Abrishami // Verify that the types of the input and output tensors are properly 1376a54efdbdSArteen Abrishami // permuted. 1377a54efdbdSArteen Abrishami if (inputType.hasRank() && outputType.hasRank()) { 1378a54efdbdSArteen Abrishami assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) && 1379a54efdbdSArteen Abrishami inputType.getRank() == outputType.getRank()); 1380a54efdbdSArteen Abrishami 1381a54efdbdSArteen Abrishami for (auto i = 0; i < outputType.getRank(); i++) { 1382a54efdbdSArteen Abrishami if (inputType.isDynamicDim(constantPerms[i]) || 1383a54efdbdSArteen Abrishami outputType.isDynamicDim(i)) 1384a54efdbdSArteen Abrishami continue; 1385a54efdbdSArteen Abrishami 1386a54efdbdSArteen Abrishami if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i)) 1387a54efdbdSArteen Abrishami return emitOpError() 1388a54efdbdSArteen Abrishami << "expected output tensor dim " << i << " to match " 1389a54efdbdSArteen Abrishami << "input dim " << constantPerms[i] << " with value of " 1390a54efdbdSArteen Abrishami << inputType.getDimSize(constantPerms[i]); 1391a54efdbdSArteen Abrishami } 1392c8b5d30fSDarshanRamakant } 13938190369eSFelix Schneider } 13948190369eSFelix Schneider return success(); 13958190369eSFelix 
Schneider } 13968190369eSFelix Schneider 13975cd074faSMaya Amrami LogicalResult TransposeOp::reifyResultShapes( 13985cd074faSMaya Amrami OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 13995cd074faSMaya Amrami 1400a54efdbdSArteen Abrishami SmallVector<int32_t> transposePerms; 14015cd074faSMaya Amrami if (getConstantPerms(transposePerms).failed()) 14025cd074faSMaya Amrami return failure(); 14035cd074faSMaya Amrami 14045cd074faSMaya Amrami Value input = getInput1(); 1405d2353695SPeiming Liu auto inputType = cast<TensorType>(input.getType()); 14065cd074faSMaya Amrami 14075cd074faSMaya Amrami SmallVector<OpFoldResult> returnedDims(inputType.getRank()); 14085cd074faSMaya Amrami for (auto dim : transposePerms) { 1409a54efdbdSArteen Abrishami int32_t dimInInput = transposePerms[dim]; 14105cd074faSMaya Amrami if (inputType.isDynamicDim(dimInInput)) 14115cd074faSMaya Amrami returnedDims[dim] = 14125cd074faSMaya Amrami builder.create<tensor::DimOp>(getLoc(), input, dimInInput) 14135cd074faSMaya Amrami .getResult(); 14145cd074faSMaya Amrami else 14155cd074faSMaya Amrami returnedDims[dim] = 14165cd074faSMaya Amrami builder.getIndexAttr(inputType.getDimSize(dimInInput)); 14175cd074faSMaya Amrami } 14185cd074faSMaya Amrami 14195cd074faSMaya Amrami reifiedReturnShapes.emplace_back(std::move(returnedDims)); 14205cd074faSMaya Amrami return success(); 14215cd074faSMaya Amrami } 14225cd074faSMaya Amrami 14235a4e7760SRob Suderman LogicalResult tosa::GatherOp::inferReturnTypeComponents( 142422426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1425057fc8e7SAmanda Tang GatherOp::Adaptor adaptor, 14265a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 14275a4e7760SRob Suderman llvm::SmallVector<int64_t> outputShape; 1428399638f9SAliia Khasanova outputShape.resize(3, ShapedType::kDynamic); 14295a4e7760SRob Suderman 1430057fc8e7SAmanda Tang ShapeAdaptor valuesShape(adaptor.getValues().getType()); 
143109349303SJacques Pienaar if (valuesShape.hasRank()) { 143209349303SJacques Pienaar outputShape[0] = valuesShape.getDimSize(0); 143309349303SJacques Pienaar outputShape[2] = valuesShape.getDimSize(2); 14345a4e7760SRob Suderman } 14355a4e7760SRob Suderman 1436057fc8e7SAmanda Tang ShapeAdaptor indicesShape(adaptor.getIndices().getType()); 143709349303SJacques Pienaar if (indicesShape.hasRank()) { 1438399638f9SAliia Khasanova if (outputShape[0] == ShapedType::kDynamic) 143909349303SJacques Pienaar outputShape[0] = indicesShape.getDimSize(0); 1440399638f9SAliia Khasanova if (outputShape[1] == ShapedType::kDynamic) 144109349303SJacques Pienaar outputShape[1] = indicesShape.getDimSize(1); 14425a4e7760SRob Suderman } 14435a4e7760SRob Suderman 14445a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 14455a4e7760SRob Suderman return success(); 14465a4e7760SRob Suderman } 14475a4e7760SRob Suderman 1448143edecaSRob Suderman LogicalResult tosa::ResizeOp::inferReturnTypeComponents( 144922426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1450057fc8e7SAmanda Tang ResizeOp::Adaptor adaptor, 1451143edecaSRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1452143edecaSRob Suderman llvm::SmallVector<int64_t, 4> outputShape; 1453399638f9SAliia Khasanova outputShape.resize(4, ShapedType::kDynamic); 1454143edecaSRob Suderman 1455057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 1456ff23599aSTatWai Chong if (!inputShape.hasRank()) 1457ff23599aSTatWai Chong return failure(); 1458ff23599aSTatWai Chong 145909349303SJacques Pienaar outputShape[0] = inputShape.getDimSize(0); 146009349303SJacques Pienaar outputShape[3] = inputShape.getDimSize(3); 1461fb4cedccSAliia Khasanova int64_t inputHeight = inputShape.getDimSize(1); 1462fb4cedccSAliia Khasanova int64_t inputWidth = inputShape.getDimSize(2); 1463143edecaSRob Suderman 1464399638f9SAliia Khasanova if ((inputHeight == 
ShapedType::kDynamic) || 1465399638f9SAliia Khasanova (inputWidth == ShapedType::kDynamic)) 1466ff23599aSTatWai Chong return failure(); 1467143edecaSRob Suderman 146811030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale(); 146911030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset(); 147011030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder(); 1471143edecaSRob Suderman 1472ff23599aSTatWai Chong // Compute the output shape based on attributes: scale, offset, and border. 1473ff23599aSTatWai Chong outputShape[1] = 1474ff23599aSTatWai Chong (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) / 1475ff23599aSTatWai Chong scaleInt[1]) + 1476ff23599aSTatWai Chong 1; 1477143edecaSRob Suderman 1478ff23599aSTatWai Chong outputShape[2] = 1479ff23599aSTatWai Chong (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) / 1480ff23599aSTatWai Chong scaleInt[3]) + 1481ff23599aSTatWai Chong 1; 1482143edecaSRob Suderman 1483143edecaSRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1484143edecaSRob Suderman return success(); 1485143edecaSRob Suderman } 1486143edecaSRob Suderman 14875a4e7760SRob Suderman LogicalResult tosa::ScatterOp::inferReturnTypeComponents( 148822426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1489057fc8e7SAmanda Tang ScatterOp::Adaptor adaptor, 14905a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 14915a4e7760SRob Suderman llvm::SmallVector<int64_t> outputShape; 1492399638f9SAliia Khasanova outputShape.resize(3, ShapedType::kDynamic); 14935a4e7760SRob Suderman 1494057fc8e7SAmanda Tang ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType()); 149509349303SJacques Pienaar if (valuesInShape.hasRank()) { 149609349303SJacques Pienaar outputShape[0] = valuesInShape.getDimSize(0); 149709349303SJacques Pienaar outputShape[1] = 
valuesInShape.getDimSize(1); 149809349303SJacques Pienaar outputShape[2] = valuesInShape.getDimSize(2); 14995a4e7760SRob Suderman } 15005a4e7760SRob Suderman 1501057fc8e7SAmanda Tang ShapeAdaptor indicesShape(adaptor.getIndices().getType()); 150209349303SJacques Pienaar if (indicesShape.hasRank()) { 1503399638f9SAliia Khasanova if (outputShape[0] == ShapedType::kDynamic) 150409349303SJacques Pienaar outputShape[0] = indicesShape.getDimSize(0); 15055a4e7760SRob Suderman } 15065a4e7760SRob Suderman 1507057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 150809349303SJacques Pienaar if (inputShape.hasRank()) { 1509399638f9SAliia Khasanova if (outputShape[0] == ShapedType::kDynamic) 151009349303SJacques Pienaar outputShape[0] = inputShape.getDimSize(0); 1511399638f9SAliia Khasanova if (outputShape[2] == ShapedType::kDynamic) 151209349303SJacques Pienaar outputShape[2] = inputShape.getDimSize(2); 15135a4e7760SRob Suderman } 15145a4e7760SRob Suderman 15155a4e7760SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 15165a4e7760SRob Suderman return success(); 15175a4e7760SRob Suderman } 15185a4e7760SRob Suderman 15195a4e7760SRob Suderman static LogicalResult ReduceInferReturnTypes( 15203500e110SAviad Cohen ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, 15215a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 15228a57bc09SFelix Schneider int64_t axisVal = axis.getValue().getSExtValue(); 15238a57bc09SFelix Schneider if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) { 15243500e110SAviad Cohen inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); 15255a4e7760SRob Suderman return success(); 15265a4e7760SRob Suderman } 15275a4e7760SRob Suderman 15285a4e7760SRob Suderman SmallVector<int64_t> outputShape; 152909349303SJacques Pienaar operandShape.getDims(outputShape); 15305a4e7760SRob Suderman outputShape[axisVal] = 1; 15313500e110SAviad Cohen 
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); 15325a4e7760SRob Suderman return success(); 15335a4e7760SRob Suderman } 15345a4e7760SRob Suderman 15353500e110SAviad Cohen #define COMPATIBLE_RETURN_TYPES(OP) \ 15363500e110SAviad Cohen bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \ 15373500e110SAviad Cohen if (l.size() != r.size() || l.size() != 1) \ 15383500e110SAviad Cohen return false; \ 15393500e110SAviad Cohen if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \ 15403500e110SAviad Cohen return false; \ 15413500e110SAviad Cohen return succeeded(verifyCompatibleShape(l[0], r[0])); \ 15423500e110SAviad Cohen } 15433500e110SAviad Cohen 15445a4e7760SRob Suderman #define REDUCE_SHAPE_INFER(OP) \ 15455a4e7760SRob Suderman LogicalResult OP::inferReturnTypeComponents( \ 154622426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, \ 1547057fc8e7SAmanda Tang OP::Adaptor adaptor, \ 15485a4e7760SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \ 15493500e110SAviad Cohen Type inputType = \ 1550057fc8e7SAmanda Tang llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \ 1551057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); \ 1552057fc8e7SAmanda Tang const Properties &prop = adaptor.getProperties(); \ 1553057fc8e7SAmanda Tang return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \ 15545a4e7760SRob Suderman inferredReturnShapes); \ 15553500e110SAviad Cohen } \ 15563500e110SAviad Cohen COMPATIBLE_RETURN_TYPES(OP) 15575a4e7760SRob Suderman 15585a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceAllOp) 15595a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceAnyOp) 15605a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceMaxOp) 15615a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceMinOp) 15625a4e7760SRob Suderman REDUCE_SHAPE_INFER(tosa::ReduceProdOp) 15635a4e7760SRob Suderman 
REDUCE_SHAPE_INFER(tosa::ReduceSumOp) 15645a4e7760SRob Suderman #undef REDUCE_SHAPE_INFER 15653500e110SAviad Cohen COMPATIBLE_RETURN_TYPES(tosa::ConcatOp) 15663500e110SAviad Cohen #undef COMPATIBLE_RETURN_TYPES 15675a4e7760SRob Suderman 15688a57bc09SFelix Schneider template <typename T> 15698a57bc09SFelix Schneider static LogicalResult verifyReduceOp(T op) { 15708a57bc09SFelix Schneider // All TOSA reduce Ops have input, output and axis. 15718a57bc09SFelix Schneider TensorType inputType = op.getInput().getType(); 15728a57bc09SFelix Schneider TensorType outputType = op.getOutput().getType(); 15738a57bc09SFelix Schneider int32_t reduceAxis = op.getAxis(); 15748a57bc09SFelix Schneider 15758a57bc09SFelix Schneider if (reduceAxis < 0) { 15768a57bc09SFelix Schneider op.emitOpError("reduce axis must not be negative"); 15778a57bc09SFelix Schneider return failure(); 15788a57bc09SFelix Schneider } 15798a57bc09SFelix Schneider if (inputType.hasRank()) { 15808a57bc09SFelix Schneider int64_t inputRank = inputType.getRank(); 15818a57bc09SFelix Schneider // We allow for a special case where the input/output shape has rank 0 and 15828a57bc09SFelix Schneider // axis is also 0. 
15838a57bc09SFelix Schneider if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) { 15848a57bc09SFelix Schneider op.emitOpError("expect input tensor rank (") 15858a57bc09SFelix Schneider << inputRank << ") to be larger than reduce axis (" << reduceAxis 15868a57bc09SFelix Schneider << ")"; 15878a57bc09SFelix Schneider return failure(); 15888a57bc09SFelix Schneider } 15898a57bc09SFelix Schneider } 15908a57bc09SFelix Schneider if (outputType.hasRank()) { 15918a57bc09SFelix Schneider int64_t outputRank = outputType.getRank(); 15928a57bc09SFelix Schneider if (inputType.hasRank() && outputRank != inputType.getRank()) { 15938a57bc09SFelix Schneider op.emitOpError( 15948a57bc09SFelix Schneider "expect output tensor rank to be equal to input tensor rank"); 15958a57bc09SFelix Schneider return failure(); 15968a57bc09SFelix Schneider } 15978a57bc09SFelix Schneider if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) { 15988a57bc09SFelix Schneider op.emitOpError("expect output tensor rank (") 15998a57bc09SFelix Schneider << outputRank << ") to be larger than reduce axis (" << reduceAxis 16008a57bc09SFelix Schneider << ")"; 16018a57bc09SFelix Schneider return failure(); 16028a57bc09SFelix Schneider } 1603a54efdbdSArteen Abrishami // We can only verify the reduced dimension size to be 1 if this is not 1604a54efdbdSArteen Abrishami // the special case of output rank == 0. 
16058a57bc09SFelix Schneider if (outputRank != 0) { 16068a57bc09SFelix Schneider auto outputShape = outputType.getShape(); 16078a57bc09SFelix Schneider if (!outputType.isDynamicDim(reduceAxis) && 16088a57bc09SFelix Schneider outputShape[reduceAxis] != 1) { 16098a57bc09SFelix Schneider op.emitOpError("expect reduced dimension size to be 1, got ") 16108a57bc09SFelix Schneider << outputShape[reduceAxis]; 16118a57bc09SFelix Schneider return failure(); 16128a57bc09SFelix Schneider } 16138a57bc09SFelix Schneider } 16148a57bc09SFelix Schneider } 16158a57bc09SFelix Schneider return success(); 16168a57bc09SFelix Schneider } 16178a57bc09SFelix Schneider 16188a57bc09SFelix Schneider LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); } 16198a57bc09SFelix Schneider LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); } 16208a57bc09SFelix Schneider LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); } 16218a57bc09SFelix Schneider LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); } 16228a57bc09SFelix Schneider LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); } 16238a57bc09SFelix Schneider LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); } 16248a57bc09SFelix Schneider 16258dea784bSRob Suderman static LogicalResult NAryInferReturnTypes( 162609349303SJacques Pienaar const ValueShapeRange &operands, 16278dea784bSRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 16288dea784bSRob Suderman llvm::SmallVector<int64_t> outShape; 16298dea784bSRob Suderman if (resolveBroadcastShape(operands, outShape).failed()) { 16308dea784bSRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents()); 16318dea784bSRob Suderman } else { 16328dea784bSRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); 16338dea784bSRob Suderman } 16348dea784bSRob Suderman return success(); 16358dea784bSRob Suderman } 
16368dea784bSRob Suderman 16378dea784bSRob Suderman #define NARY_SHAPE_INFER(OP) \ 16388dea784bSRob Suderman LogicalResult OP::inferReturnTypeComponents( \ 163922426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, \ 1640d425f589SJacques Pienaar ValueShapeRange operands, DictionaryAttr attributes, \ 16415e118f93SMehdi Amini OpaqueProperties properties, RegionRange regions, \ 16428dea784bSRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \ 16438dea784bSRob Suderman return NAryInferReturnTypes(operands, inferredReturnShapes); \ 16448dea784bSRob Suderman } 16458dea784bSRob Suderman 16468dea784bSRob Suderman NARY_SHAPE_INFER(tosa::AbsOp) 16478dea784bSRob Suderman NARY_SHAPE_INFER(tosa::AddOp) 16488dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp) 16498dea784bSRob Suderman NARY_SHAPE_INFER(tosa::BitwiseAndOp) 16508dea784bSRob Suderman NARY_SHAPE_INFER(tosa::BitwiseOrOp) 16518dea784bSRob Suderman NARY_SHAPE_INFER(tosa::BitwiseXorOp) 16528dea784bSRob Suderman NARY_SHAPE_INFER(tosa::BitwiseNotOp) 1653143edecaSRob Suderman NARY_SHAPE_INFER(tosa::CastOp) 16548dea784bSRob Suderman NARY_SHAPE_INFER(tosa::CeilOp) 16558dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ClampOp) 16568dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ClzOp) 1657d57f158aSJerry-Ge NARY_SHAPE_INFER(tosa::CosOp) 16588dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ExpOp) 16598dea784bSRob Suderman NARY_SHAPE_INFER(tosa::FloorOp) 16608dea784bSRob Suderman NARY_SHAPE_INFER(tosa::GreaterEqualOp) 16618dea784bSRob Suderman NARY_SHAPE_INFER(tosa::GreaterOp) 1662143edecaSRob Suderman NARY_SHAPE_INFER(tosa::IdentityOp) 166382383d5fSTai Ly NARY_SHAPE_INFER(tosa::IntDivOp) 16648dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogOp) 16658dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalAndOp) 16668dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp) 16678dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalNotOp) 16688dea784bSRob Suderman 
NARY_SHAPE_INFER(tosa::LogicalOrOp) 16698dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalRightShiftOp) 16708dea784bSRob Suderman NARY_SHAPE_INFER(tosa::LogicalXorOp) 16718dea784bSRob Suderman NARY_SHAPE_INFER(tosa::MaximumOp) 16728dea784bSRob Suderman NARY_SHAPE_INFER(tosa::MinimumOp) 16738dea784bSRob Suderman NARY_SHAPE_INFER(tosa::MulOp) 16748dea784bSRob Suderman NARY_SHAPE_INFER(tosa::NegateOp) 16758dea784bSRob Suderman NARY_SHAPE_INFER(tosa::PowOp) 16768dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ReciprocalOp) 1677143edecaSRob Suderman NARY_SHAPE_INFER(tosa::RescaleOp) 16788dea784bSRob Suderman NARY_SHAPE_INFER(tosa::ReverseOp) 16798dea784bSRob Suderman NARY_SHAPE_INFER(tosa::RsqrtOp) 1680d57f158aSJerry-Ge NARY_SHAPE_INFER(tosa::SinOp) 16818dea784bSRob Suderman NARY_SHAPE_INFER(tosa::SelectOp) 16828dea784bSRob Suderman NARY_SHAPE_INFER(tosa::SubOp) 16838dea784bSRob Suderman NARY_SHAPE_INFER(tosa::TanhOp) 16841fef1f97SManupa Karunaratne NARY_SHAPE_INFER(tosa::ErfOp) 16858dea784bSRob Suderman NARY_SHAPE_INFER(tosa::SigmoidOp) 16868dea784bSRob Suderman #undef PRED_SHAPE_INFER 16878dea784bSRob Suderman 1688f2832c22SRob Suderman static LogicalResult poolingInferReturnTypes( 1689057fc8e7SAmanda Tang ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, 1690057fc8e7SAmanda Tang ArrayRef<int64_t> pad, 1691f2832c22SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1692f2832c22SRob Suderman llvm::SmallVector<int64_t> outputShape; 1693399638f9SAliia Khasanova outputShape.resize(4, ShapedType::kDynamic); 1694f2832c22SRob Suderman 1695f2832c22SRob Suderman // We only know the rank if the input type is unranked. 
169609349303SJacques Pienaar if (!inputShape) { 1697f2832c22SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1698f2832c22SRob Suderman return success(); 1699f2832c22SRob Suderman } 1700f2832c22SRob Suderman 1701f2832c22SRob Suderman // Batch and number of channels are identical for pooling layer. 170209349303SJacques Pienaar outputShape[0] = inputShape.getDimSize(0); 170309349303SJacques Pienaar outputShape[3] = inputShape.getDimSize(3); 1704f2832c22SRob Suderman 1705fb4cedccSAliia Khasanova int64_t height = inputShape.getDimSize(1); 1706fb4cedccSAliia Khasanova int64_t width = inputShape.getDimSize(2); 1707f2832c22SRob Suderman 1708fb4cedccSAliia Khasanova if (!ShapedType::isDynamic(height)) { 1709fb4cedccSAliia Khasanova int64_t padded = height + pad[0] + pad[1] - kernel[0]; 1710f2832c22SRob Suderman outputShape[1] = padded / stride[0] + 1; 1711f2832c22SRob Suderman } 1712f2832c22SRob Suderman 1713fb4cedccSAliia Khasanova if (!ShapedType::isDynamic(width)) { 1714fb4cedccSAliia Khasanova int64_t padded = width + pad[2] + pad[3] - kernel[1]; 1715f2832c22SRob Suderman outputShape[2] = padded / stride[1] + 1; 1716f2832c22SRob Suderman } 1717f2832c22SRob Suderman 1718f2832c22SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 1719f2832c22SRob Suderman return success(); 1720f2832c22SRob Suderman } 1721f2832c22SRob Suderman 172211dda1a2SRob Suderman LogicalResult Conv2DOp::inferReturnTypeComponents( 172322426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1724057fc8e7SAmanda Tang Conv2DOp::Adaptor adaptor, 172511dda1a2SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1726399638f9SAliia Khasanova llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic); 172711dda1a2SRob Suderman 1728399638f9SAliia Khasanova int64_t inputWidth = ShapedType::kDynamic; 1729399638f9SAliia Khasanova int64_t inputHeight = ShapedType::kDynamic; 1730399638f9SAliia 
Khasanova int64_t weightWidth = ShapedType::kDynamic; 1731399638f9SAliia Khasanova int64_t weightHeight = ShapedType::kDynamic; 173211dda1a2SRob Suderman 173311dda1a2SRob Suderman // Input shape describes input width/height and batch. 173409349303SJacques Pienaar 1735057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 173609349303SJacques Pienaar if (inputShape.hasRank()) { 173709349303SJacques Pienaar outputShape[0] = inputShape.getDimSize(0); 173809349303SJacques Pienaar inputHeight = inputShape.getDimSize(1); 173909349303SJacques Pienaar inputWidth = inputShape.getDimSize(2); 174011dda1a2SRob Suderman } 174111dda1a2SRob Suderman 174211dda1a2SRob Suderman // Weight shapes describes the filter width/height and the output channels. 1743057fc8e7SAmanda Tang ShapeAdaptor weightShape(adaptor.getWeight().getType()); 174409349303SJacques Pienaar if (weightShape.hasRank()) { 174509349303SJacques Pienaar outputShape[3] = weightShape.getDimSize(0); 174609349303SJacques Pienaar weightHeight = weightShape.getDimSize(1); 174709349303SJacques Pienaar weightWidth = weightShape.getDimSize(2); 174811dda1a2SRob Suderman } 174911dda1a2SRob Suderman 175011dda1a2SRob Suderman // Bias shape can describe the output channels. 1751057fc8e7SAmanda Tang ShapeAdaptor biasShape(adaptor.getBias().getType()); 175209349303SJacques Pienaar if (biasShape.hasRank()) { 175311dda1a2SRob Suderman outputShape[3] = ShapedType::isDynamic(outputShape[3]) 175409349303SJacques Pienaar ? 
biasShape.getDimSize(0) 175511dda1a2SRob Suderman : outputShape[3]; 175611dda1a2SRob Suderman } 175711dda1a2SRob Suderman 175811030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); 175911030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> stride = adaptor.getStride(); 176011030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> padding = adaptor.getPad(); 176111dda1a2SRob Suderman 176211dda1a2SRob Suderman if (!ShapedType::isDynamic(inputHeight) && 176311dda1a2SRob Suderman !ShapedType::isDynamic(weightHeight)) { 1764fb4cedccSAliia Khasanova int64_t inputSize = inputHeight + padding[0] + padding[1]; 1765fb4cedccSAliia Khasanova int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; 1766fb4cedccSAliia Khasanova int64_t unstridedResult = inputSize - filterSize + 1; 176711dda1a2SRob Suderman outputShape[1] = (unstridedResult - 1) / stride[0] + 1; 176811dda1a2SRob Suderman } 176911dda1a2SRob Suderman 177011dda1a2SRob Suderman if (!ShapedType::isDynamic(inputWidth) && 177111dda1a2SRob Suderman !ShapedType::isDynamic(weightWidth)) { 1772fb4cedccSAliia Khasanova int64_t inputSize = inputWidth + padding[2] + padding[3]; 1773fb4cedccSAliia Khasanova int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; 1774fb4cedccSAliia Khasanova int64_t unstridedResult = inputSize - filterSize + 1; 177511dda1a2SRob Suderman outputShape[2] = (unstridedResult - 1) / stride[1] + 1; 177611dda1a2SRob Suderman } 177711dda1a2SRob Suderman 177811dda1a2SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 177911dda1a2SRob Suderman return success(); 178011dda1a2SRob Suderman } 178111dda1a2SRob Suderman 1782360a03c9SJack Frankland LogicalResult Conv2DOp::verify() { 1783360a03c9SJack Frankland if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) 1784360a03c9SJack Frankland return failure(); 1785360a03c9SJack Frankland return success(); 1786360a03c9SJack Frankland } 17871be88f5aSRiver Riddle 178811dda1a2SRob 
Suderman LogicalResult Conv3DOp::inferReturnTypeComponents( 178922426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1790057fc8e7SAmanda Tang Conv3DOp::Adaptor adaptor, 179111dda1a2SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1792399638f9SAliia Khasanova llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic); 179311dda1a2SRob Suderman 1794399638f9SAliia Khasanova int64_t inputWidth = ShapedType::kDynamic; 1795399638f9SAliia Khasanova int64_t inputHeight = ShapedType::kDynamic; 1796399638f9SAliia Khasanova int64_t inputDepth = ShapedType::kDynamic; 179711dda1a2SRob Suderman 1798399638f9SAliia Khasanova int64_t weightWidth = ShapedType::kDynamic; 1799399638f9SAliia Khasanova int64_t weightHeight = ShapedType::kDynamic; 1800399638f9SAliia Khasanova int64_t weightDepth = ShapedType::kDynamic; 180111dda1a2SRob Suderman 180211dda1a2SRob Suderman // Input shape describes input width/height and batch. 1803057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 180409349303SJacques Pienaar if (inputShape.hasRank()) { 180509349303SJacques Pienaar outputShape[0] = inputShape.getDimSize(0); 1806eb04f321STatWai Chong inputDepth = inputShape.getDimSize(1); 1807eb04f321STatWai Chong inputHeight = inputShape.getDimSize(2); 1808eb04f321STatWai Chong inputWidth = inputShape.getDimSize(3); 180911dda1a2SRob Suderman } 181011dda1a2SRob Suderman 181111dda1a2SRob Suderman // Weight shapes describes the filter width/height and the output channels. 
1812057fc8e7SAmanda Tang ShapeAdaptor weightShape(adaptor.getWeight().getType()); 181309349303SJacques Pienaar if (weightShape.hasRank()) { 181409349303SJacques Pienaar outputShape[4] = weightShape.getDimSize(0); 1815eb04f321STatWai Chong weightDepth = weightShape.getDimSize(1); 1816eb04f321STatWai Chong weightHeight = weightShape.getDimSize(2); 1817eb04f321STatWai Chong weightWidth = weightShape.getDimSize(3); 181811dda1a2SRob Suderman } 181911dda1a2SRob Suderman 182011dda1a2SRob Suderman // Bias shape can describe the output channels. 1821057fc8e7SAmanda Tang ShapeAdaptor biasShape(adaptor.getBias().getType()); 1822eb04f321STatWai Chong if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) { 1823eb04f321STatWai Chong outputShape[4] = biasShape.getDimSize(0); 182411dda1a2SRob Suderman } 182511dda1a2SRob Suderman 182611030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); 182711030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> stride = adaptor.getStride(); 182811030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> pad = adaptor.getPad(); 182911dda1a2SRob Suderman 1830eb04f321STatWai Chong if (!ShapedType::isDynamic(inputDepth) && 1831eb04f321STatWai Chong !ShapedType::isDynamic(weightDepth)) { 1832eb04f321STatWai Chong int32_t inputSize = inputDepth + pad[0] + pad[1]; 1833eb04f321STatWai Chong int32_t filterSize = (weightDepth - 1) * dilation[0] + 1; 183411dda1a2SRob Suderman int32_t unstridedResult = inputSize - filterSize + 1; 183511dda1a2SRob Suderman outputShape[1] = (unstridedResult - 1) / stride[0] + 1; 183611dda1a2SRob Suderman } 183711dda1a2SRob Suderman 1838eb04f321STatWai Chong if (!ShapedType::isDynamic(inputHeight) && 1839eb04f321STatWai Chong !ShapedType::isDynamic(weightHeight)) { 1840eb04f321STatWai Chong int32_t inputSize = inputHeight + pad[2] + pad[3]; 1841eb04f321STatWai Chong int32_t filterSize = (weightHeight - 1) * dilation[1] + 1; 184211dda1a2SRob Suderman int32_t unstridedResult = inputSize 
- filterSize + 1; 184311dda1a2SRob Suderman outputShape[2] = (unstridedResult - 1) / stride[1] + 1; 184411dda1a2SRob Suderman } 184511dda1a2SRob Suderman 1846eb04f321STatWai Chong if (!ShapedType::isDynamic(inputWidth) && 1847eb04f321STatWai Chong !ShapedType::isDynamic(weightWidth)) { 1848eb04f321STatWai Chong int32_t inputSize = inputWidth + pad[4] + pad[5]; 1849eb04f321STatWai Chong int32_t filterSize = (weightWidth - 1) * dilation[2] + 1; 185011dda1a2SRob Suderman int32_t unstridedResult = inputSize - filterSize + 1; 185111dda1a2SRob Suderman outputShape[3] = (unstridedResult - 1) / stride[2] + 1; 185211dda1a2SRob Suderman } 185311dda1a2SRob Suderman 185411dda1a2SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 185511dda1a2SRob Suderman return success(); 185611dda1a2SRob Suderman } 185711dda1a2SRob Suderman 1858360a03c9SJack Frankland LogicalResult Conv3DOp::verify() { 1859360a03c9SJack Frankland if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) 1860360a03c9SJack Frankland return failure(); 1861360a03c9SJack Frankland return success(); 1862360a03c9SJack Frankland } 18631be88f5aSRiver Riddle 1864f2832c22SRob Suderman LogicalResult AvgPool2dOp::inferReturnTypeComponents( 186522426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1866057fc8e7SAmanda Tang AvgPool2dOp::Adaptor adaptor, 1867f2832c22SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1868057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 1869057fc8e7SAmanda Tang const Properties &prop = adaptor.getProperties(); 1870057fc8e7SAmanda Tang return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, 1871057fc8e7SAmanda Tang inferredReturnShapes); 1872f2832c22SRob Suderman } 1873f2832c22SRob Suderman 1874f2832c22SRob Suderman LogicalResult MaxPool2dOp::inferReturnTypeComponents( 187522426110SRamkumar Ramachandra MLIRContext *context, 
::std::optional<Location> location, 1876057fc8e7SAmanda Tang MaxPool2dOp::Adaptor adaptor, 1877f2832c22SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1878057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 1879057fc8e7SAmanda Tang const Properties &prop = adaptor.getProperties(); 1880057fc8e7SAmanda Tang return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, 1881057fc8e7SAmanda Tang inferredReturnShapes); 1882f2832c22SRob Suderman } 1883f2832c22SRob Suderman 188411dda1a2SRob Suderman LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( 188522426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1886057fc8e7SAmanda Tang DepthwiseConv2DOp::Adaptor adaptor, 188711dda1a2SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1888399638f9SAliia Khasanova llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic); 188911dda1a2SRob Suderman 1890399638f9SAliia Khasanova int64_t inputWidth = ShapedType::kDynamic; 1891399638f9SAliia Khasanova int64_t inputHeight = ShapedType::kDynamic; 1892399638f9SAliia Khasanova int64_t inputChannels = ShapedType::kDynamic; 189311dda1a2SRob Suderman 1894399638f9SAliia Khasanova int64_t weightWidth = ShapedType::kDynamic; 1895399638f9SAliia Khasanova int64_t weightHeight = ShapedType::kDynamic; 1896399638f9SAliia Khasanova int64_t depthChannels = ShapedType::kDynamic; 189711dda1a2SRob Suderman 189811dda1a2SRob Suderman // Input shape describes input width/height and batch. 
1899057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 190009349303SJacques Pienaar if (inputShape.hasRank()) { 190109349303SJacques Pienaar outputShape[0] = inputShape.getDimSize(0); 190209349303SJacques Pienaar inputHeight = inputShape.getDimSize(1); 190309349303SJacques Pienaar inputWidth = inputShape.getDimSize(2); 190409349303SJacques Pienaar inputChannels = inputShape.getDimSize(3); 190511dda1a2SRob Suderman } 190611dda1a2SRob Suderman 190711dda1a2SRob Suderman // Weight shapes describes the filter width/height and the output channels. 1908057fc8e7SAmanda Tang ShapeAdaptor weightShape(adaptor.getWeight().getType()); 190909349303SJacques Pienaar if (weightShape.hasRank()) { 191009349303SJacques Pienaar weightHeight = weightShape.getDimSize(0); 191109349303SJacques Pienaar weightWidth = weightShape.getDimSize(1); 191211dda1a2SRob Suderman inputChannels = ShapedType::isDynamic(inputChannels) 191309349303SJacques Pienaar ? weightShape.getDimSize(2) 191411dda1a2SRob Suderman : inputChannels; 191509349303SJacques Pienaar depthChannels = weightShape.getDimSize(3); 191611dda1a2SRob Suderman } 191711dda1a2SRob Suderman 191811dda1a2SRob Suderman // If both inputChannels and depthChannels are available we can determine 191911dda1a2SRob Suderman // the output channels. 192011dda1a2SRob Suderman if (!ShapedType::isDynamic(inputChannels) && 192111dda1a2SRob Suderman !ShapedType::isDynamic(depthChannels)) { 192211dda1a2SRob Suderman outputShape[3] = inputChannels * depthChannels; 192311dda1a2SRob Suderman } 192411dda1a2SRob Suderman 192511dda1a2SRob Suderman // Bias shape can describe the output channels. 1926057fc8e7SAmanda Tang ShapeAdaptor biasShape(adaptor.getBias().getType()); 192709349303SJacques Pienaar if (biasShape.hasRank()) { 192811dda1a2SRob Suderman outputShape[3] = ShapedType::isDynamic(outputShape[3]) 192909349303SJacques Pienaar ? 
biasShape.getDimSize(0) 193011dda1a2SRob Suderman : outputShape[3]; 193111dda1a2SRob Suderman } 193211dda1a2SRob Suderman 193311030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); 193411030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> padding = adaptor.getPad(); 193511030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> stride = adaptor.getStride(); 193611dda1a2SRob Suderman 193711dda1a2SRob Suderman if (!ShapedType::isDynamic(inputHeight) && 193811dda1a2SRob Suderman !ShapedType::isDynamic(weightHeight)) { 1939fb4cedccSAliia Khasanova int64_t inputSize = inputHeight + padding[0] + padding[1]; 1940fb4cedccSAliia Khasanova int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; 1941fb4cedccSAliia Khasanova int64_t unstridedResult = inputSize - filterSize + 1; 194211dda1a2SRob Suderman outputShape[1] = (unstridedResult - 1) / stride[0] + 1; 194311dda1a2SRob Suderman } 194411dda1a2SRob Suderman 194511dda1a2SRob Suderman if (!ShapedType::isDynamic(inputWidth) && 194611dda1a2SRob Suderman !ShapedType::isDynamic(weightWidth)) { 1947fb4cedccSAliia Khasanova int64_t inputSize = inputWidth + padding[2] + padding[3]; 1948fb4cedccSAliia Khasanova int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; 1949fb4cedccSAliia Khasanova int64_t unstridedResult = inputSize - filterSize + 1; 195011dda1a2SRob Suderman outputShape[2] = (unstridedResult - 1) / stride[1] + 1; 195111dda1a2SRob Suderman } 195211dda1a2SRob Suderman 195311dda1a2SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 195411dda1a2SRob Suderman return success(); 195511dda1a2SRob Suderman } 195611dda1a2SRob Suderman 1957360a03c9SJack Frankland LogicalResult DepthwiseConv2DOp::verify() { 1958360a03c9SJack Frankland if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) 1959360a03c9SJack Frankland return failure(); 1960360a03c9SJack Frankland return success(); 1961360a03c9SJack Frankland } 19621be88f5aSRiver Riddle 
196311dda1a2SRob Suderman LogicalResult TransposeConv2DOp::inferReturnTypeComponents( 196422426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 1965057fc8e7SAmanda Tang TransposeConv2DOp::Adaptor adaptor, 196611dda1a2SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 196711030c7dSAlexander Shaposhnikov // outputShape is mutable. 196811030c7dSAlexander Shaposhnikov llvm::SmallVector<int64_t> outputShape = 196911030c7dSAlexander Shaposhnikov convertToMlirShape(adaptor.getOutShape()); 197011dda1a2SRob Suderman 1971399638f9SAliia Khasanova int64_t inputWidth = ShapedType::kDynamic; 1972399638f9SAliia Khasanova int64_t inputHeight = ShapedType::kDynamic; 1973399638f9SAliia Khasanova int64_t weightWidth = ShapedType::kDynamic; 1974399638f9SAliia Khasanova int64_t weightHeight = ShapedType::kDynamic; 197511dda1a2SRob Suderman 197611dda1a2SRob Suderman // Input shape describes input width/height and batch. 1977057fc8e7SAmanda Tang ShapeAdaptor inputShape(adaptor.getInput().getType()); 197809349303SJacques Pienaar if (inputShape.hasRank()) { 197911dda1a2SRob Suderman outputShape[0] = ShapedType::isDynamic(outputShape[0]) 198009349303SJacques Pienaar ? inputShape.getDimSize(0) 198111dda1a2SRob Suderman : outputShape[0]; 198209349303SJacques Pienaar inputHeight = inputShape.getDimSize(1); 198309349303SJacques Pienaar inputWidth = inputShape.getDimSize(2); 198411dda1a2SRob Suderman } 198511dda1a2SRob Suderman 198611dda1a2SRob Suderman // Weight shapes describes the filter width/height and the output channels. 1987057fc8e7SAmanda Tang ShapeAdaptor weightShape(adaptor.getFilter().getType()); 198809349303SJacques Pienaar if (weightShape.hasRank()) { 198911dda1a2SRob Suderman outputShape[3] = ShapedType::isDynamic(outputShape[3]) 199009349303SJacques Pienaar ? 
weightShape.getDimSize(0) 199111dda1a2SRob Suderman : outputShape[3]; 199209349303SJacques Pienaar weightHeight = weightShape.getDimSize(1); 199309349303SJacques Pienaar weightWidth = weightShape.getDimSize(2); 199411dda1a2SRob Suderman } 199511dda1a2SRob Suderman 199611dda1a2SRob Suderman // Bias shape can describe the output channels. 1997057fc8e7SAmanda Tang ShapeAdaptor biasShape(adaptor.getInput().getType()); 199809349303SJacques Pienaar if (biasShape.hasRank()) { 199911dda1a2SRob Suderman outputShape[3] = ShapedType::isDynamic(outputShape[3]) 200009349303SJacques Pienaar ? biasShape.getDimSize(0) 200111dda1a2SRob Suderman : outputShape[3]; 200211dda1a2SRob Suderman } 200311dda1a2SRob Suderman 200411030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> padding = adaptor.getOutPad(); 200511030c7dSAlexander Shaposhnikov llvm::ArrayRef<int64_t> stride = adaptor.getStride(); 200611dda1a2SRob Suderman 200711dda1a2SRob Suderman if (!ShapedType::isDynamic(inputHeight) && 200811dda1a2SRob Suderman !ShapedType::isDynamic(weightHeight)) { 2009fb4cedccSAliia Khasanova int64_t calculateSize = 2010fcbf3fafSRob Suderman (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight; 2011fb4cedccSAliia Khasanova outputShape[1] = 2012fb4cedccSAliia Khasanova ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1]; 201311dda1a2SRob Suderman } 201411dda1a2SRob Suderman 201511dda1a2SRob Suderman if (!ShapedType::isDynamic(inputWidth) && 201611dda1a2SRob Suderman !ShapedType::isDynamic(weightWidth)) { 2017fb4cedccSAliia Khasanova int64_t calculateSize = 2018fcbf3fafSRob Suderman (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth; 2019fb4cedccSAliia Khasanova outputShape[2] = 2020fb4cedccSAliia Khasanova ShapedType::isDynamic(outputShape[2]) ? 
calculateSize : outputShape[2]; 202111dda1a2SRob Suderman } 202211dda1a2SRob Suderman 202311dda1a2SRob Suderman inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); 202411dda1a2SRob Suderman return success(); 202511dda1a2SRob Suderman } 202611dda1a2SRob Suderman 2027360a03c9SJack Frankland LogicalResult TransposeConv2DOp::verify() { 2028360a03c9SJack Frankland if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) 2029360a03c9SJack Frankland return failure(); 2030360a03c9SJack Frankland return success(); 2031360a03c9SJack Frankland } 2032360a03c9SJack Frankland 20331b00b94fSRob Suderman LogicalResult IfOp::inferReturnTypeComponents( 203422426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 2035057fc8e7SAmanda Tang IfOp::Adaptor adaptor, 20361b00b94fSRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 20371b00b94fSRob Suderman llvm::SmallVector<tosa::YieldOp> yieldOps; 2038057fc8e7SAmanda Tang for (Region *region : adaptor.getRegions()) { 20391b00b94fSRob Suderman for (auto &block : *region) 20401b00b94fSRob Suderman if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator())) 20411b00b94fSRob Suderman yieldOps.push_back(returnOp); 20421b00b94fSRob Suderman } 20431b00b94fSRob Suderman 20441b00b94fSRob Suderman if (yieldOps.empty()) 20451b00b94fSRob Suderman return failure(); 20461b00b94fSRob Suderman 20471b00b94fSRob Suderman // Get the initial type information for the yield op. 
20481b00b94fSRob Suderman llvm::SmallVector<ValueKnowledge> resultKnowledge; 20491b00b94fSRob Suderman resultKnowledge.reserve(yieldOps.front().getNumOperands()); 20501b00b94fSRob Suderman for (auto operand : yieldOps.front().getOperands()) { 20511b00b94fSRob Suderman resultKnowledge.push_back( 20521b00b94fSRob Suderman ValueKnowledge::getKnowledgeFromType(operand.getType())); 20531b00b94fSRob Suderman } 20541b00b94fSRob Suderman 20551b00b94fSRob Suderman for (auto yieldOp : yieldOps) { 20561b00b94fSRob Suderman if (resultKnowledge.size() != yieldOp.getNumOperands()) 20571b00b94fSRob Suderman return failure(); 20581b00b94fSRob Suderman 205989de9cc8SMehdi Amini for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 20601b00b94fSRob Suderman int32_t index = it.index(); 20611b00b94fSRob Suderman auto meet = ValueKnowledge::meet( 20621b00b94fSRob Suderman resultKnowledge[index], 20631b00b94fSRob Suderman ValueKnowledge::getKnowledgeFromType(it.value().getType())); 20641b00b94fSRob Suderman if (!meet) 20651b00b94fSRob Suderman continue; 20661b00b94fSRob Suderman resultKnowledge[index] = meet; 20671b00b94fSRob Suderman } 20681b00b94fSRob Suderman } 20691b00b94fSRob Suderman 207009349303SJacques Pienaar for (const ValueKnowledge &result : resultKnowledge) { 2071b0532286SRob Suderman inferredReturnShapes.push_back(result.getShapedTypeComponents()); 20721b00b94fSRob Suderman } 2073b0532286SRob Suderman 2074b0532286SRob Suderman return success(); 2075b0532286SRob Suderman } 2076b0532286SRob Suderman 2077b0532286SRob Suderman LogicalResult WhileOp::inferReturnTypeComponents( 207822426110SRamkumar Ramachandra MLIRContext *context, ::std::optional<Location> location, 2079057fc8e7SAmanda Tang WhileOp::Adaptor adaptor, 2080b0532286SRob Suderman SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 2081b0532286SRob Suderman llvm::SmallVector<tosa::YieldOp> yieldOps; 2082057fc8e7SAmanda Tang for (auto &block : adaptor.getBody()) 2083b0532286SRob Suderman if (auto 
returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator())) 2084b0532286SRob Suderman yieldOps.push_back(returnOp); 2085b0532286SRob Suderman 2086b0532286SRob Suderman // TOSA's while must have a tosa.yield as its terminator. If not found this 2087b0532286SRob Suderman // tosa.while is invalid. 2088b0532286SRob Suderman if (yieldOps.empty()) 2089b0532286SRob Suderman return failure(); 2090b0532286SRob Suderman 2091b0532286SRob Suderman // Get the initial type information from the operand types. 2092b0532286SRob Suderman llvm::SmallVector<ValueKnowledge> resultKnowledge; 2093b0532286SRob Suderman resultKnowledge.reserve(yieldOps.front().getNumOperands()); 2094b0532286SRob Suderman for (auto operand : yieldOps.front().getOperands()) { 2095b0532286SRob Suderman resultKnowledge.push_back( 2096b0532286SRob Suderman ValueKnowledge::getKnowledgeFromType(operand.getType())); 2097b0532286SRob Suderman } 2098b0532286SRob Suderman 2099b0532286SRob Suderman for (auto yieldOp : yieldOps) { 2100b0532286SRob Suderman if (resultKnowledge.size() != yieldOp.getNumOperands()) 2101b0532286SRob Suderman return failure(); 2102b0532286SRob Suderman 210389de9cc8SMehdi Amini for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 2104b0532286SRob Suderman int32_t index = it.index(); 2105b0532286SRob Suderman if (auto meet = ValueKnowledge::meet( 2106b0532286SRob Suderman resultKnowledge[index], 2107b0532286SRob Suderman ValueKnowledge::getKnowledgeFromType(it.value().getType()))) { 2108b0532286SRob Suderman resultKnowledge[index] = meet; 21095161835dSlipracer } 2110b0532286SRob Suderman } 2111b0532286SRob Suderman } 2112b0532286SRob Suderman 2113b0532286SRob Suderman for (const ValueKnowledge &result : resultKnowledge) { 2114b0532286SRob Suderman inferredReturnShapes.push_back(result.getShapedTypeComponents()); 21151b00b94fSRob Suderman } 21161b00b94fSRob Suderman 21171b00b94fSRob Suderman return success(); 21181b00b94fSRob Suderman } 21191b00b94fSRob Suderman 212081566001SJakub 
Kuderski std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { 2121c1fa60b4STres Popp if (auto vt = llvm::dyn_cast<VectorType>(getType())) 212281566001SJakub Kuderski return llvm::to_vector<4>(vt.getShape()); 212381566001SJakub Kuderski return std::nullopt; 212481566001SJakub Kuderski } 212581566001SJakub Kuderski 2126da4d191fSTatWai Chong // parse and print of IfOp refer to the implementation of SCF dialect. 2127da4d191fSTatWai Chong ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { 2128da4d191fSTatWai Chong // Create the regions for 'then'. 2129da4d191fSTatWai Chong result.regions.reserve(2); 2130da4d191fSTatWai Chong Region *thenRegion = result.addRegion(); 2131da4d191fSTatWai Chong Region *elseRegion = result.addRegion(); 2132da4d191fSTatWai Chong 2133da4d191fSTatWai Chong auto &builder = parser.getBuilder(); 2134da4d191fSTatWai Chong OpAsmParser::UnresolvedOperand cond; 2135da4d191fSTatWai Chong // Create a i1 tensor type for the boolean condition. 2136da4d191fSTatWai Chong Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); 2137da4d191fSTatWai Chong if (parser.parseOperand(cond) || 2138da4d191fSTatWai Chong parser.resolveOperand(cond, i1Type, result.operands)) 2139da4d191fSTatWai Chong return failure(); 2140da4d191fSTatWai Chong // Parse optional results type list. 2141da4d191fSTatWai Chong if (parser.parseOptionalArrowTypeList(result.types)) 2142da4d191fSTatWai Chong return failure(); 2143da4d191fSTatWai Chong // Parse the 'then' region. 2144da4d191fSTatWai Chong if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) 2145da4d191fSTatWai Chong return failure(); 2146da4d191fSTatWai Chong 2147da4d191fSTatWai Chong // If we find an 'else' keyword then parse the 'else' region. 
2148da4d191fSTatWai Chong if (!parser.parseOptionalKeyword("else")) { 2149da4d191fSTatWai Chong if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) 2150da4d191fSTatWai Chong return failure(); 2151da4d191fSTatWai Chong } 2152da4d191fSTatWai Chong 2153da4d191fSTatWai Chong // Parse the optional attribute list. 2154da4d191fSTatWai Chong if (parser.parseOptionalAttrDict(result.attributes)) 2155da4d191fSTatWai Chong return failure(); 2156da4d191fSTatWai Chong return success(); 2157da4d191fSTatWai Chong } 2158da4d191fSTatWai Chong 2159da4d191fSTatWai Chong void IfOp::print(OpAsmPrinter &p) { 2160da4d191fSTatWai Chong bool printBlockTerminators = false; 2161da4d191fSTatWai Chong 2162da4d191fSTatWai Chong p << " " << getCond(); 2163da4d191fSTatWai Chong if (!getResults().empty()) { 2164da4d191fSTatWai Chong p << " -> (" << getResultTypes() << ")"; 2165da4d191fSTatWai Chong // Print yield explicitly if the op defines values. 2166da4d191fSTatWai Chong printBlockTerminators = true; 2167da4d191fSTatWai Chong } 2168da4d191fSTatWai Chong p << ' '; 2169da4d191fSTatWai Chong p.printRegion(getThenBranch(), 2170da4d191fSTatWai Chong /*printEntryBlockArgs=*/false, 2171da4d191fSTatWai Chong /*printBlockTerminators=*/printBlockTerminators); 2172da4d191fSTatWai Chong 2173da4d191fSTatWai Chong // Print the 'else' regions if it exists and has a block. 
2174da4d191fSTatWai Chong auto &elseRegion = getElseBranch(); 2175da4d191fSTatWai Chong if (!elseRegion.empty()) { 2176da4d191fSTatWai Chong p << " else "; 2177da4d191fSTatWai Chong p.printRegion(elseRegion, 2178da4d191fSTatWai Chong /*printEntryBlockArgs=*/false, 2179da4d191fSTatWai Chong /*printBlockTerminators=*/printBlockTerminators); 2180da4d191fSTatWai Chong } 2181da4d191fSTatWai Chong 2182da4d191fSTatWai Chong p.printOptionalAttrDict((*this)->getAttrs()); 2183da4d191fSTatWai Chong } 2184da4d191fSTatWai Chong 21856ed2d30dSFelix Schneider LogicalResult ReverseOp::verify() { 2186c6876b4eSJerry-Ge TensorType inputType = getInput1().getType(); 21876ed2d30dSFelix Schneider TensorType outputType = getOutput().getType(); 21886ed2d30dSFelix Schneider int32_t reverseAxis = getAxis(); 21896ed2d30dSFelix Schneider 21906ed2d30dSFelix Schneider if (reverseAxis < 0) 21916ed2d30dSFelix Schneider return emitOpError("expected non-negative reverse axis"); 21926ed2d30dSFelix Schneider if (inputType.hasRank()) { 21936ed2d30dSFelix Schneider int64_t inputRank = inputType.getRank(); 21946ed2d30dSFelix Schneider // We allow for a special case where the input/output shape has rank 0 and 21956ed2d30dSFelix Schneider // axis is also 0. 
21966ed2d30dSFelix Schneider if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0)) 21976ed2d30dSFelix Schneider return emitOpError("expect input tensor rank (") 21986ed2d30dSFelix Schneider << inputRank << ") to be larger than reverse axis (" << reverseAxis 21996ed2d30dSFelix Schneider << ")"; 22006ed2d30dSFelix Schneider } 22016ed2d30dSFelix Schneider if (outputType.hasRank()) { 22026ed2d30dSFelix Schneider int64_t outputRank = outputType.getRank(); 22036ed2d30dSFelix Schneider if (inputType.hasRank() && outputRank != inputType.getRank()) 22046ed2d30dSFelix Schneider return emitOpError( 22056ed2d30dSFelix Schneider "expect output tensor rank to be equal to input tensor rank"); 22066ed2d30dSFelix Schneider if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0)) 22076ed2d30dSFelix Schneider return emitOpError("expect output tensor rank (") 22086ed2d30dSFelix Schneider << outputRank << ") to be larger than reverse axis (" 22096ed2d30dSFelix Schneider << reverseAxis << ")"; 22106ed2d30dSFelix Schneider } 22116ed2d30dSFelix Schneider return success(); 22126ed2d30dSFelix Schneider } 22136ed2d30dSFelix Schneider 2214da4d191fSTatWai Chong // parse and print of WhileOp refer to the implementation of SCF dialect. 
2215da4d191fSTatWai Chong ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { 2216da4d191fSTatWai Chong SmallVector<OpAsmParser::Argument, 4> regionArgs; 2217da4d191fSTatWai Chong SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; 2218da4d191fSTatWai Chong Region *cond = result.addRegion(); 2219da4d191fSTatWai Chong Region *body = result.addRegion(); 2220da4d191fSTatWai Chong 2221da4d191fSTatWai Chong OptionalParseResult listResult = 2222da4d191fSTatWai Chong parser.parseOptionalAssignmentList(regionArgs, operands); 2223da4d191fSTatWai Chong if (listResult.has_value() && failed(listResult.value())) 2224da4d191fSTatWai Chong return failure(); 2225da4d191fSTatWai Chong 2226da4d191fSTatWai Chong FunctionType functionType; 2227da4d191fSTatWai Chong SMLoc typeLoc = parser.getCurrentLocation(); 2228da4d191fSTatWai Chong if (failed(parser.parseColonType(functionType))) 2229da4d191fSTatWai Chong return failure(); 2230da4d191fSTatWai Chong 2231da4d191fSTatWai Chong result.addTypes(functionType.getResults()); 2232da4d191fSTatWai Chong 2233da4d191fSTatWai Chong if (functionType.getNumInputs() != operands.size()) { 2234da4d191fSTatWai Chong return parser.emitError(typeLoc) 2235da4d191fSTatWai Chong << "expected as many input types as operands " 2236da4d191fSTatWai Chong << "(expected " << operands.size() << " got " 2237da4d191fSTatWai Chong << functionType.getNumInputs() << ")"; 2238da4d191fSTatWai Chong } 2239da4d191fSTatWai Chong 2240da4d191fSTatWai Chong // Resolve input operands. 2241da4d191fSTatWai Chong if (failed(parser.resolveOperands(operands, functionType.getInputs(), 2242da4d191fSTatWai Chong parser.getCurrentLocation(), 2243da4d191fSTatWai Chong result.operands))) 2244da4d191fSTatWai Chong return failure(); 2245da4d191fSTatWai Chong 2246da4d191fSTatWai Chong // Propagate the types into the region arguments. 
2247da4d191fSTatWai Chong for (size_t i = 0, e = regionArgs.size(); i != e; ++i) 2248da4d191fSTatWai Chong regionArgs[i].type = functionType.getInput(i); 2249da4d191fSTatWai Chong 2250da4d191fSTatWai Chong return failure(parser.parseRegion(*cond, regionArgs) || 2251da4d191fSTatWai Chong parser.parseKeyword("do") || parser.parseRegion(*body) || 2252da4d191fSTatWai Chong parser.parseOptionalAttrDictWithKeyword(result.attributes)); 2253da4d191fSTatWai Chong } 2254da4d191fSTatWai Chong 2255da4d191fSTatWai Chong static void printInitializationList(OpAsmPrinter &parser, 2256da4d191fSTatWai Chong Block::BlockArgListType blocksArgs, 2257da4d191fSTatWai Chong ValueRange initializers, 2258da4d191fSTatWai Chong StringRef prefix = "") { 2259da4d191fSTatWai Chong assert(blocksArgs.size() == initializers.size() && 2260da4d191fSTatWai Chong "expected same length of arguments and initializers"); 2261da4d191fSTatWai Chong if (initializers.empty()) 2262da4d191fSTatWai Chong return; 2263da4d191fSTatWai Chong 2264da4d191fSTatWai Chong parser << prefix << '('; 2265da4d191fSTatWai Chong llvm::interleaveComma( 2266da4d191fSTatWai Chong llvm::zip(blocksArgs, initializers), parser, 2267da4d191fSTatWai Chong [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); 2268da4d191fSTatWai Chong parser << ")"; 2269da4d191fSTatWai Chong } 2270da4d191fSTatWai Chong 2271da4d191fSTatWai Chong void WhileOp::print(OpAsmPrinter &parser) { 2272da4d191fSTatWai Chong printInitializationList(parser, getCond().front().getArguments(), getInputs(), 2273da4d191fSTatWai Chong " "); 2274da4d191fSTatWai Chong parser << " : "; 2275da4d191fSTatWai Chong parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes()); 2276da4d191fSTatWai Chong parser << ' '; 2277da4d191fSTatWai Chong parser.printRegion(getCond(), /*printEntryBlockArgs=*/false); 2278da4d191fSTatWai Chong parser << " do "; 2279da4d191fSTatWai Chong parser.printRegion(getBody()); 2280da4d191fSTatWai Chong 
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 2281da4d191fSTatWai Chong } 2282da4d191fSTatWai Chong 22838dea784bSRob Suderman //===----------------------------------------------------------------------===// 2284f09db6a3SJerry-Ge // TOSA Shape and Shape Operators Helper functions. 2285f09db6a3SJerry-Ge //===----------------------------------------------------------------------===// 2286f09db6a3SJerry-Ge 2287f09db6a3SJerry-Ge bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) { 2288f09db6a3SJerry-Ge return mlir::isa<tosa::shapeType>(t); 2289f09db6a3SJerry-Ge } 2290f09db6a3SJerry-Ge 2291f09db6a3SJerry-Ge LogicalResult 2292f09db6a3SJerry-Ge mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError, 2293f09db6a3SJerry-Ge int rank) { 2294f09db6a3SJerry-Ge if (rank < 0) 2295f09db6a3SJerry-Ge return emitError() << "invalid rank (must be >= 0): " << rank; 2296f09db6a3SJerry-Ge return success(); 2297f09db6a3SJerry-Ge } 2298f09db6a3SJerry-Ge 2299f09db6a3SJerry-Ge LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) { 2300f09db6a3SJerry-Ge for (auto v : op->getOperands()) { 2301f09db6a3SJerry-Ge if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) { 2302f09db6a3SJerry-Ge Operation *definingOp = v.getDefiningOp(); 2303f09db6a3SJerry-Ge if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) { 2304f09db6a3SJerry-Ge return op->emitOpError("shape operand is not compile time resolvable"); 2305f09db6a3SJerry-Ge } 2306f09db6a3SJerry-Ge } 2307f09db6a3SJerry-Ge } 2308f09db6a3SJerry-Ge return success(); 2309f09db6a3SJerry-Ge } 2310f09db6a3SJerry-Ge 2311f09db6a3SJerry-Ge LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) { 2312f09db6a3SJerry-Ge for (auto type : op->getOperandTypes()) { 2313f09db6a3SJerry-Ge if (!mlir::isa<mlir::tosa::shapeType>(type)) { 2314f09db6a3SJerry-Ge return op->emitOpError("must have operands with tosa shape type"); 2315f09db6a3SJerry-Ge } 2316f09db6a3SJerry-Ge } 
2317f09db6a3SJerry-Ge for (auto type : op->getResultTypes()) { 2318f09db6a3SJerry-Ge if (!mlir::isa<mlir::tosa::shapeType>(type)) { 2319f09db6a3SJerry-Ge return op->emitOpError("must have result with tosa shape type"); 2320f09db6a3SJerry-Ge } 2321f09db6a3SJerry-Ge } 2322f09db6a3SJerry-Ge return success(); 2323f09db6a3SJerry-Ge } 2324f09db6a3SJerry-Ge 2325f09db6a3SJerry-Ge LogicalResult 2326f09db6a3SJerry-Ge OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) { 2327f09db6a3SJerry-Ge if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) || 2328f09db6a3SJerry-Ge failed(verifyTosaShapeOperator(op))) 2329f09db6a3SJerry-Ge return failure(); 2330f09db6a3SJerry-Ge 2331f09db6a3SJerry-Ge // delegate function that returns rank of shape type 2332f09db6a3SJerry-Ge auto getRank = [](const Type type) { 2333f09db6a3SJerry-Ge return mlir::cast<mlir::tosa::shapeType>(type).getRank(); 2334f09db6a3SJerry-Ge }; 2335f09db6a3SJerry-Ge auto operandTypes = op->getOperandTypes(); 2336f09db6a3SJerry-Ge auto resultTypes = op->getResultTypes(); 2337f09db6a3SJerry-Ge 2338f09db6a3SJerry-Ge auto rank = getRank(*op->getOperandTypes().begin()); 2339f09db6a3SJerry-Ge for (auto type : operandTypes) { 2340f09db6a3SJerry-Ge if (getRank(type) != rank) { 2341f09db6a3SJerry-Ge return op->emitOpError("operands don't have matching ranks"); 2342f09db6a3SJerry-Ge } 2343f09db6a3SJerry-Ge } 2344f09db6a3SJerry-Ge for (auto type : resultTypes) { 2345f09db6a3SJerry-Ge if (getRank(type) != rank) { 2346f09db6a3SJerry-Ge return op->emitOpError("result shape has different rank than operands"); 2347f09db6a3SJerry-Ge } 2348f09db6a3SJerry-Ge } 2349f09db6a3SJerry-Ge return success(); 2350f09db6a3SJerry-Ge } 2351f09db6a3SJerry-Ge 2352f09db6a3SJerry-Ge //===----------------------------------------------------------------------===// 2353f09db6a3SJerry-Ge // TOSA Shape Operators verify functions. 
2354f09db6a3SJerry-Ge //===----------------------------------------------------------------------===// 2355f09db6a3SJerry-Ge 2356f09db6a3SJerry-Ge LogicalResult tosa::ConstShapeOp::verify() { 2357f09db6a3SJerry-Ge // check that number of elements in value attr equal to rank of result shape 2358f09db6a3SJerry-Ge auto count = getValue().getNumElements(); 2359f09db6a3SJerry-Ge auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank(); 2360f09db6a3SJerry-Ge if (!(count == rank || (count == 1 && rank == 0))) { 2361f09db6a3SJerry-Ge return emitOpError("expect number of elements in attribute value (") 2362f09db6a3SJerry-Ge << count << ") to be equal to the rank (" << rank 2363f09db6a3SJerry-Ge << ") for the result shape type"; 2364f09db6a3SJerry-Ge } 2365f09db6a3SJerry-Ge return success(); 2366f09db6a3SJerry-Ge } 2367f09db6a3SJerry-Ge 2368f09db6a3SJerry-Ge //===----------------------------------------------------------------------===// 2369f1182bd6SMogball // TOSA Attribute Definitions. 2370f1182bd6SMogball //===----------------------------------------------------------------------===// 2371f1182bd6SMogball 2372f1182bd6SMogball #define GET_ATTRDEF_CLASSES 2373f1182bd6SMogball #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" 2374f1182bd6SMogball 2375f1182bd6SMogball //===----------------------------------------------------------------------===// 2376f09db6a3SJerry-Ge // TOSA Type Definitions. 2377f09db6a3SJerry-Ge //===----------------------------------------------------------------------===// 2378f09db6a3SJerry-Ge #define GET_TYPEDEF_CLASSES 2379f09db6a3SJerry-Ge #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc" 2380f09db6a3SJerry-Ge 2381f09db6a3SJerry-Ge //===----------------------------------------------------------------------===// 2382b2812113SSuraj Sudhir // TOSA Operator Definitions. 
2383b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 2384b2812113SSuraj Sudhir 2385b2812113SSuraj Sudhir #define GET_OP_CLASSES 2386b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" 2387