//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

namespace {
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};

/// This class implements the bytecode interface for the Tosa dialect.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // TOSA dialect constants only support ElementsAttr, unlike the standard
  // dialect constant which supports all attributes.
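  // For example (illustrative): a folded dense<...> tensor value is
  // materialized as a tosa.const, while a !tosa.shape value backed by a
  // DenseIntElementsAttr is materialized as a tosa.const_shape.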
  if (llvm::isa<tosa::shapeType>(type) &&
      llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
                                        Attribute &attr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(attr))) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    }
    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
      typeAttr = TypeAttr::get(typedAttr.getType());
    }
    return success();
  }

  Type type;
  if (failed(parser.parseColonType(type))) {
    return parser.emitError(parser.getCurrentLocation()) << "expected type";
  }
  typeAttr = TypeAttr::get(type);
  return success();
}

void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                                 Attribute attr) {
  bool needsSpace = false;
  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
  if (!typedAttr || typedAttr.getType() != type.getValue()) {
    p << ": ";
    p.printAttribute(type);
    needsSpace = true; // subsequent attr value needs a space separator
  }
  if (attr) {
    if (needsSpace)
      p << ' ';
    p << "= ";
    p.printAttribute(attr);
  }
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());

  RankedTensorType weightType;
  if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
    weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
  else
    weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  // Must be ranked tensor types
  if (!inputType) {
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
    return failure();
  }
  if (!weightType) {
    if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
      op.emitOpError("expect a ranked tensor for filter, got ")
          << op.getFilter();
    } else {
      op.emitOpError("expect a ranked tensor for weight, got ")
          << op.getWeight();
    }
    return failure();
  }

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // A quantized type must have constructed the quantization attribute, and an
  // unquantized type must not have one.
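  // For example (illustrative): a conv2d whose input and weight use quantized
  // element types (e.g. !quant.uniform<i8:f32, ...>) must carry a
  // quantization_info attribute, while an all-float conv2d must not.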
  if ((inputIsQuant && !op.getQuantizationInfo()) ||
      (!inputIsQuant && op.getQuantizationInfo())) {
    op.emitOpError("quantizationattr is required for quantized type, and not "
                   "allowed for float type");
    return failure();
  }
  return success();
}

LogicalResult tosa::ConstOp::verify() {

  auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }

  return success();
}

template <typename T>
static LogicalResult verifyConvOpModes(T op) {
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3FNType>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  return success();
}

LogicalResult tosa::ArgMaxOp::verify() {
  // Ensure output is of 32-bit integer
  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
  if (!resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  // Ensure axis is within the tensor rank
  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  const int64_t axis = getAxisAttr().getInt();
  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  return success();
}

LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputType = llvm::cast<ShapedType>(getInput().getType());

  auto inputETy = inputType.getElementType();
  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
    inputETy = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
    resultETy = quantType.getStorageType();

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if ((inputETy.isF32() && resultETy.isF32()) ||
      (inputETy.isF16() && resultETy.isF16()) ||
      (inputETy.isBF16() && resultETy.isBF16()) ||
      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
      (inputETy.isInteger(16) && resultETy.isInteger(16)))
    return success();

  return emitOpError("input/output element types are incompatible.");
}
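// In short (mirroring the checks above): integer avg_pool2d accumulates in
// i32, f16 pools may accumulate in f16 or f32, and bf16/f32 pools accumulate
// in f32.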
LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type maxFpType = getMaxFpAttr().getType();
  mlir::Type minFpType = getMinFpAttr().getType();
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  // If the input datatype is float, check that the two min/max_fp attributes
  // share the same type and that their type is either the same as the input's
  // datatype, or a float type whose bitwidth > input datatype bitwidth.
  if (!inputETy.isInteger(dataTypeBitWidth)) {
    if (((maxFpType != minFpType) ||
         (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
                                       inputETy.getIntOrFloatBitWidth())))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bit width of the input & weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, DenseI64ArrayAttr pad,
                                     DenseI64ArrayAttr stride,
                                     DenseI64ArrayAttr dilation,
                                     TypeAttr accType) {
  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("acc_type", accType);
  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape
/// attributes.
static void buildTransConvOpWithQuantInfo(
    OpBuilder &builder, OperationState &result, Type outputType, Value input,
    Value weight, Value bias, DenseI64ArrayAttr outpad,
    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("out_shape", outputShape);
  result.addAttribute("acc_type", accType);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {
  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also intended to be generated where a
/// fully_connected op would be constructed but the weight is not a constant.
/// In this case, the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
        inputType.getElementType());
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as
/// it has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input,
                              DenseArrayAttr kernel, DenseArrayAttr stride,
                              DenseArrayAttr pad, TypeAttr accType) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  result.addAttribute("acc_type", accType);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator that needs to create its
/// own OptionalAttr quantization_attr parameter to scale the padding values
/// correctly. No pad_const is interpreted as zero-padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator when an explicit pad_const
/// value is passed in. It also optionally constructs quantization_attr.
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                 OperationState &result,
                                                 Type outputType, Value input,
                                                 Value paddings,
                                                 Value padConst) {
  result.addOperands({input, paddings, padConst});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      // TODO(jennik): Update function to have better case handling for
      // invalid operands and for ranked tensors.
      return failure();
    }
    outRank = std::max(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());

  if (!inputShape.hasRank())
    return failure();

  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  // Note that we can support this calculation symbolically
  // in the future e.g.
[x, y, z] -> [x, y, z / 2 - 1] if (inWidth != ShapedType::kDynamic) outputShape[2] = inWidth / 2 + 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::FFT2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, FFT2dOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { inferredReturnShapes.push_back( ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType()))); inferredReturnShapes.push_back( ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType()))); return success(); } LogicalResult tosa::ConcatOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ConcatOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { // Infer all dimension sizes by reducing based on inputs. const Properties &prop = adaptor.getProperties(); int32_t axis = prop.axis.getValue().getSExtValue(); llvm::SmallVector outputShape; bool hasRankedInput = false; for (auto operand : adaptor.getOperands()) { ShapeAdaptor operandShape(operand.getType()); if (!operandShape.hasRank()) continue; // Copy the Operand's rank. if (!hasRankedInput) outputShape.resize(operandShape.getRank(), ShapedType::kDynamic); // Copy shapes until the dim is non-dynamic. for (int i = 0, s = operandShape.getRank(); i < s; i++) { if (i == axis || operandShape.isDynamicDim(i)) continue; if (outputShape[i] == ShapedType::kDynamic) outputShape[i] = operandShape.getDimSize(i); if (outputShape[i] != operandShape.getDimSize(i)) return emitOptionalError(location, "Cannot concat tensors with different sizes" " on the non-axis dimension ", i); } hasRankedInput = true; } Type inputType = llvm::cast(adaptor.getInput1().getType()[0]).getElementType(); if (!hasRankedInput) { inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } // Determine the dimension size along the concatenation axis. int64_t concatDimSize = 0; for (auto operand : adaptor.getOperands()) { ShapeAdaptor operandShape(operand.getType()); // We need to know the length of the concatenation axis of all inputs to // determine the dimension size of the output shape. 
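  // For example (illustrative): concatenating tensor<2x3xf32> and
  // tensor<2x5xf32> along axis 1 infers tensor<2x8xf32>; if any input's axis
  // dimension is dynamic, the result's axis dimension is dynamic as well.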
if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) { concatDimSize = ShapedType::kDynamic; break; } concatDimSize += operandShape.getDimSize(axis); } outputShape[axis] = concatDimSize; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } LogicalResult tosa::EqualOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { auto elementType = IntegerType::get(context, /*width=*/1); llvm::SmallVector outShape; if (resolveBroadcastShape(operands, outShape).failed()) { inferredReturnShapes.push_back(ShapedTypeComponents(elementType)); return success(); } inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType)); return success(); } bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != r.size() || l.size() != 1) return false; return succeeded(verifyCompatibleShape(l[0], r[0])); } LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, FullyConnectedOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); ShapeAdaptor weightShape(adaptor.getWeight().getType()); ShapeAdaptor biasShape(adaptor.getBias().getType()); // All shapes are dynamic. SmallVector outShape; outShape.resize(2, ShapedType::kDynamic); if (inputShape.hasRank()) { outShape[0] = inputShape.getDimSize(0); } if (weightShape.hasRank()) { outShape[1] = weightShape.getDimSize(0); } if (biasShape.hasRank()) { outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0) : outShape[1]; } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); return success(); } LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, MatMulOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor lhsShape(adaptor.getA().getType()); ShapeAdaptor rhsShape(adaptor.getB().getType()); // All shapes are dynamic. SmallVector outShape; outShape.resize(3, ShapedType::kDynamic); if (lhsShape.hasRank()) { outShape[0] = lhsShape.getDimSize(0); outShape[1] = lhsShape.getDimSize(1); } if (rhsShape.hasRank()) { outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0) : outShape[0]; outShape[2] = rhsShape.getDimSize(2); } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); return success(); } LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, PadOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); auto paddingRank = cast(adaptor.getPadding().getType()).getRank(); SmallVector outputShape; // If the input rank is unknown, we can infer the output rank using the // padding shape's rank divided by 2. if (!inputShape.hasRank()) { outputShape.resize(paddingRank / 2, ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } SmallVector paddingValues; // If the paddings value is not a constant, all dimensions must be dynamic. 
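  // For example (illustrative): an input of tensor<1x2xf32> with constant
  // padding values [0, 0, 1, 1] (before/after pairs per dimension) infers
  // output dims [1 + 0 + 0, 2 + 1 + 1] = [1, 4].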
if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(), paddingValues)) { outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } outputShape.reserve(inputShape.getRank()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { if (inputShape.isDynamicDim(i)) { outputShape.push_back(ShapedType::kDynamic); continue; } auto padFront = paddingValues[i * 2]; auto padBack = paddingValues[i * 2 + 1]; if (padFront < 0 || padBack < 0) { // if either padding for dim i is -1, output dim is unknown outputShape.push_back(ShapedType::kDynamic); continue; } outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::PadOp::verify() { RankedTensorType inputType = getInput1().getType(); RankedTensorType outputType = getOutput().getType(); auto paddingRank = cast(getPadding().getType()).getRank(); if (inputType.getRank() != outputType.getRank()) return emitOpError() << "expect same input and output tensor rank."; if (paddingRank != inputType.getRank() * 2) return emitOpError() << "expected padding tensor dim 0 to have size " << inputType.getRank() * 2 << " (2*rank(shape1)) but got size " << paddingRank; return success(); } static SmallVector convertToMlirShape(ArrayRef shape) { return to_vector(llvm::map_range(shape, [](int64_t dim) { return dim == -1 ? ShapedType::kDynamic : dim; })); } LogicalResult tosa::SliceOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, SliceOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); SmallVector start; SmallVector size; if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) || !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) { auto rank = cast(adaptor.getSize().getType()).getRank(); SmallVector fallback(rank, ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType)); return success(); } // if size[i] is -1, all remaining elements in dimension i are included // in the slice, similar to TF. 
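  // For example (illustrative): for an input dimension of size 10, start = 2
  // with size = -1 infers an output dimension of 10 - 2 = 8, while start = 2
  // with size = 4 infers 4.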
ShapeAdaptor inputShape(adaptor.getInput1().getType()); // initialize outputShape to all unknown SmallVector outputShape(size.size(), ShapedType::kDynamic); if (inputShape.hasRank()) { for (size_t i = 0; i < size.size(); i++) { if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 && (ShapedType::isDynamic(inputShape.getDimSize(i)) || start[i] < inputShape.getDimSize(i))) { // size[i] is not 0 and not < -1, and start[i] is in valid range if (ShapedType::isDynamic(inputShape.getDimSize(i))) { // input shape has unknown dim[i] - only valid if size[i] > 0 if (size[i] > 0) { outputShape[i] = size[i]; } } else { // input shape has known dim[i] if (size[i] == -1) { outputShape[i] = inputShape.getDimSize(i) - start[i]; } else if (start[i] + size[i] <= inputShape.getDimSize(i)) { // start[i] + size[i] is within bound of input shape's dim[i] outputShape[i] = size[i]; } } } } } else { outputShape = convertToMlirShape(size); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::SliceOp::verify() { auto inputType = llvm::dyn_cast(getInput1().getType()); if (!inputType) return success(); auto startShapeRank = llvm::cast(getStart().getType()).getRank(); if (inputType.getRank() != startShapeRank) return emitOpError( "length of start attribute is not equal rank of input shape"); auto sizeShapeRank = llvm::cast(getSize().getType()).getRank(); if (inputType.getRank() != sizeShapeRank) return emitOpError( "length of size attribute is not equal rank of input shape"); return success(); } LogicalResult tosa::MulOp::verify() { auto resElemType = getElementTypeOrSelf(getOutput()); // Verify if the element type among operands and result match tosa // specification. if (auto resIntType = dyn_cast(resElemType)) { IntegerType lhsIntType = cast(getElementTypeOrSelf(getInput1())); IntegerType rhsIntType = cast(getElementTypeOrSelf(getInput2())); if (lhsIntType != rhsIntType) return emitOpError("requires the same element type for all operands"); // Though the spec requires the element type of result to be i32, a more // relaxed way is provided at dialect level for easier cooperating with // other dialects. if (lhsIntType.getWidth() > resIntType.getWidth()) return emitOpError("invalid data type size for operands or result"); } else { // For other supported type, the spec requires requires the same element // type for all operands (excludes `shift` operand) and results. for (int i = 0; i < 2; ++i) { if (getElementTypeOrSelf(getOperand(i)) != resElemType) return emitOpError( "requires the same element type for all operands and results"); } } // Verify the op has same ranks for all main operands (excludes extra operands // such as shift of mul op, so this is the only difference with the built-in // `SameOperandsAndResultRank` trait) and results types, if known. // delegate function that returns true if type is a shaped type with known // rank auto hasRank = [](const Type type) { if (auto shaped_type = dyn_cast(type)) return shaped_type.hasRank(); return false; }; auto rankedOperandTypes = llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank)); auto rankedResultTypes = llvm::make_filter_range(getOperation()->getResultTypes(), hasRank); // If all operands and results are unranked, then no further verification. 
if (rankedOperandTypes.empty() && rankedResultTypes.empty()) return success(); // delegate function that returns rank of shaped type with known rank auto getRank = [](const Type type) { return cast(type).getRank(); }; auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin()) : getRank(*rankedResultTypes.begin()); for (size_t i = 0; i < 2; ++i) { if (rank != getRank(rankedOperandTypes[i])) { return emitOpError("operands don't have matching ranks"); } } for (const auto type : rankedResultTypes) { if (rank != getRank(type)) { return emitOpError("result type has different rank than operands"); } } return success(); } LogicalResult tosa::TableOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TableOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); if (!inputShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } inferredReturnShapes.resize(1); inputShape.getDims(inferredReturnShapes[0]); return success(); } LogicalResult tosa::TableOp::verify() { TensorType inputType = getInput1().getType(); TensorType outputType = getOutput().getType(); if (inputType.hasRank() && outputType.hasRank() && inputType.getRank() != outputType.getRank()) return emitOpError() << "expected input tensor rank to equal result tensor rank"; auto inputDims = inputType.getShape(); auto outputDims = outputType.getShape(); for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) { int64_t dim = it.index(); auto [inputDim, outputDim] = it.value(); if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) { return emitOpError() << "dim(result, " << dim << ") = " << outputDim << " doesn't match dim(input, " << dim << ") = " << inputDim; } } return success(); } LogicalResult tosa::TileOp::getConstantMultiples(SmallVector &multiples) { // Multiples must be constants. DenseIntElementsAttr multiplesAttr; if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr))) return failure(); multiples = llvm::to_vector( llvm::map_range(multiplesAttr.getValues(), [](const APInt &val) { return val.getSExtValue(); })); return success(); } LogicalResult tosa::TileOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TileOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { DenseIntElementsAttr multiplesAttr; if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr))) return failure(); SmallVector multiples = llvm::to_vector( llvm::map_range(multiplesAttr.getValues(), [](const APInt &val) { return val.getSExtValue(); })); ShapeAdaptor inputShape(adaptor.getInput1().getType()); SmallVector outputShape; if (!inputShape.hasRank()) { outputShape.resize(multiples.size(), ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } else if (static_cast(inputShape.getRank()) != multiples.size()) return failure(); // Any non dynamic dimension can be multiplied to a known size. 
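  // For example (illustrative): input dims [3, ?, 5] with multiples
  // [2, 2, 2] infer output dims [6, ?, 10].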
outputShape.reserve(multiples.size()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { int64_t dim = inputShape.getDimSize(i); if (dim != ShapedType::kDynamic) dim *= multiples[i]; outputShape.push_back(dim); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::TileOp::verify() { ShapedType inputType = llvm::cast(getInput1().getType()); ShapedType outputType = llvm::cast(getType()); shapeType multiplesType = llvm::cast(getMultiples().getType()); auto multiplesRank = multiplesType.getRank(); if (inputType.hasRank()) { if (inputType.getRank() != multiplesRank) return emitOpError("expect 'multiples' to have rank ") << inputType.getRank() << " but got " << multiplesRank << "."; if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) return emitOpError("expect same input and output tensor rank."); } else if (outputType.hasRank() && outputType.getRank() != multiplesRank) return emitOpError("expect 'multiples' array to have length ") << outputType.getRank() << " but got " << multiplesRank << "."; SmallVector multiples; if (getConstantMultiples(multiples).succeeded() && llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; })) return emitOpError( "expect element of 'multiples' to be positive integer or -1."); return success(); } bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != r.size() || l.size() != 1) return false; return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]); } LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ReshapeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); llvm::SmallVector newShapeValue = convertToMlirShape(adaptor.getNewShape()); // We cannot infer from the total number of elements so we must take the // shape attribute as exact. if (!inputShape.hasRank() || !inputShape.hasStaticShape()) { inferredReturnShapes.push_back( ShapedTypeComponents(newShapeValue, inputType)); return success(); } // Determine the number of elements covered by the slice of all static // dimensions. This allows us to infer the length of the remaining dynamic // dimension. int64_t numElements = inputShape.getNumElements(); int64_t staticMul = 1; for (auto val : newShapeValue) { if (!ShapedType::isDynamic(val)) { staticMul *= val; } } // Determine the length of the dynamic dimension. 
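  // For example (illustrative): reshaping 24 input elements with new_shape
  // [2, -1, 4] resolves the -1 to 24 / (2 * 4) = 3.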
for (auto &val : newShapeValue) { if (ShapedType::isDynamic(val)) val = numElements / staticMul; } inferredReturnShapes.push_back( ShapedTypeComponents(newShapeValue, inputType)); return success(); } llvm::LogicalResult tosa::ReshapeOp::verify() { TensorType inputType = getInput1().getType(); RankedTensorType outputType = getType(); if ((int64_t)getNewShape().size() != outputType.getRank()) return emitOpError() << "new shape does not match result rank"; for (auto [newShapeDim, outputShapeDim] : zip(getNewShape(), outputType.getShape())) { if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim) return emitOpError() << "new shape is inconsistent with result shape"; if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1) return emitOpError() << "new shape has invalid tensor dimension size " << newShapeDim; } if (inputType.hasStaticShape()) { int64_t inputElementsNum = inputType.getNumElements(); if (outputType.hasStaticShape()) { int64_t outputElementsNum = outputType.getNumElements(); if (inputElementsNum != outputElementsNum) { return emitOpError() << "cannot reshape " << inputElementsNum << " elements into " << outputElementsNum; } } int64_t newShapeElementsNum = std::accumulate( getNewShape().begin(), getNewShape().end(), 1LL, [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; }); bool isStaticNewShape = llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; }); if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) || (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) { return emitOpError() << "cannot reshape " << inputElementsNum << " elements into " << newShapeElementsNum; } } int missingDims = llvm::count(getNewShape(), -1); if (missingDims > 1) return emitOpError() << "expected at most one target dimension to be -1"; return mlir::success(); } LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector &perms) { // Perms must be constants. DenseIntElementsAttr permsAttr; if (!matchPattern(getPerms(), m_Constant(&permsAttr))) return failure(); perms.clear(); for (auto v : permsAttr.getValues()) perms.push_back(v.getSExtValue()); return success(); } LogicalResult tosa::TransposeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TransposeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); ShapeAdaptor permsShape(adaptor.getPerms().getType()); // We cannot infer anything from a rank-0 "permutation" tensor. if (permsShape.hasRank() && permsShape.getRank() == 0) return failure(); // If input rank and permutation length is unknown, the output rank is // unknown. if (!inputShape.hasRank() || !permsShape.hasRank() || permsShape.isDynamicDim(0)) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } // This would imply the number of permutations does not match the rank of // the input which is illegal. if (permsShape.getDimSize(0) != inputShape.getRank()) { return failure(); } SmallVector outputShape; // Rank-0 means no permutations matter. if (inputShape.getRank() == 0) { inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } // Check whether the input dimensions are all the same. bool allTheSame = true; for (int i = 1, s = inputShape.getRank(); i < s; i++) { if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) { allTheSame = false; break; } } // If all of the input dimensions are the same we don't care about the // permutation. 
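  // For example (illustrative): any permutation of a tensor<4x4x4xf32> still
  // infers tensor<4x4x4xf32>, so the permutation values do not matter here.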
if (allTheSame) { outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0)); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); // If the permuations are a constant we can directly determine the output // shape. DenseIntElementsAttr attr; if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) && attr.getType().getRank() == 1) { ShapeAdaptor permShape = attr; // Constant permutation must be the same length as the input rank. if (inputShape.getRank() != permShape.getRank()) return emitOptionalError(location, "constant permutation must be the same length" " as the input rank"); // Constant permutation values must be within the input rank. for (int i = 0, e = inputShape.getRank(); i < e; i++) { if (inputShape.getRank() <= permShape.getDimSize(i)) return failure(); } outputShape.reserve(inputShape.getRank()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i)); } } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::TransposeOp::verify() { TensorType inputType = getInput1().getType(); TensorType permType = getPerms().getType(); TensorType outputType = getOutput().getType(); if (permType.hasRank() && permType.getRank() != 1) return emitOpError() << "expected permutation tensor to be rank 1 but got rank " << permType.getRank(); if (inputType.hasRank() && permType.hasRank()) if (!permType.isDynamicDim(0) && permType.getDimSize(0) != inputType.getRank()) return emitOpError() << "expected permutation tensor dim 0 to have size " << inputType.getRank() << " (input rank) but got size " << permType.getDimSize(0); if (inputType.hasRank() && outputType.hasRank() && inputType.getRank() != outputType.getRank()) return emitOpError() << "expected input tensor rank to equal result tensor rank"; if (outputType.hasRank() && permType.hasRank()) if (!permType.isDynamicDim(0) && permType.getDimSize(0) != outputType.getRank()) return emitOpError() << "expected permutation tensor dim 0 to have size " << outputType.getRank() << " (output rank) but got size " << permType.getDimSize(0); SmallVector constantPerms; if (succeeded(getConstantPerms(constantPerms))) { // Assert that the permutation tensor has a rank, which means that the // rank has been verified above. assert(permType.hasRank() && "Unexpectedly found permutation tensor without rank"); if (!llvm::all_of(constantPerms, [&constantPerms](int32_t s) { return s >= 0 && static_cast(s) < constantPerms.size(); }) || !isPermutationVector(llvm::to_vector(llvm::map_range( constantPerms, [](int32_t v) -> int64_t { return v; })))) return emitOpError() << "expected valid permutation tensor"; // Verify that the types of the input and output tensors are properly // permuted. 
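    // For example (illustrative): with perms = [2, 0, 1], output dim 0 must
    // match input dim 2, output dim 1 must match input dim 0, and output
    // dim 2 must match input dim 1.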
if (inputType.hasRank() && outputType.hasRank()) { assert(constantPerms.size() == static_cast(inputType.getRank()) && inputType.getRank() == outputType.getRank()); for (auto i = 0; i < outputType.getRank(); i++) { if (inputType.isDynamicDim(constantPerms[i]) || outputType.isDynamicDim(i)) continue; if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i)) return emitOpError() << "expected output tensor dim " << i << " to match " << "input dim " << constantPerms[i] << " with value of " << inputType.getDimSize(constantPerms[i]); } } } return success(); } LogicalResult TransposeOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { SmallVector transposePerms; if (getConstantPerms(transposePerms).failed()) return failure(); Value input = getInput1(); auto inputType = cast(input.getType()); SmallVector returnedDims(inputType.getRank()); for (auto dim : transposePerms) { int32_t dimInInput = transposePerms[dim]; if (inputType.isDynamicDim(dimInInput)) returnedDims[dim] = builder.create(getLoc(), input, dimInInput) .getResult(); else returnedDims[dim] = builder.getIndexAttr(inputType.getDimSize(dimInInput)); } reifiedReturnShapes.emplace_back(std::move(returnedDims)); return success(); } LogicalResult tosa::GatherOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, GatherOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamic); ShapeAdaptor valuesShape(adaptor.getValues().getType()); if (valuesShape.hasRank()) { outputShape[0] = valuesShape.getDimSize(0); outputShape[2] = valuesShape.getDimSize(2); } ShapeAdaptor indicesShape(adaptor.getIndices().getType()); if (indicesShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); if (outputShape[1] == ShapedType::kDynamic) outputShape[1] = indicesShape.getDimSize(1); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::ResizeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ResizeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(4, ShapedType::kDynamic); ShapeAdaptor inputShape(adaptor.getInput().getType()); if (!inputShape.hasRank()) return failure(); outputShape[0] = inputShape.getDimSize(0); outputShape[3] = inputShape.getDimSize(3); int64_t inputHeight = inputShape.getDimSize(1); int64_t inputWidth = inputShape.getDimSize(2); if ((inputHeight == ShapedType::kDynamic) || (inputWidth == ShapedType::kDynamic)) return failure(); llvm::ArrayRef scaleInt = adaptor.getScale(); llvm::ArrayRef offsetInt = adaptor.getOffset(); llvm::ArrayRef borderInt = adaptor.getBorder(); // Compute the output shape based on attributes: scale, offset, and border. 
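  // For example (illustrative): with input height 8, scale = [4, 2, 4, 2],
  // offset = [0, 0] and border = [0, 0], the inferred output height is
  // ((8 - 1) * 4 - 0 + 0) / 2 + 1 = 15.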
outputShape[1] = (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) / scaleInt[1]) + 1; outputShape[2] = (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) / scaleInt[3]) + 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::ScatterOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ScatterOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamic); ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType()); if (valuesInShape.hasRank()) { outputShape[0] = valuesInShape.getDimSize(0); outputShape[1] = valuesInShape.getDimSize(1); outputShape[2] = valuesInShape.getDimSize(2); } ShapeAdaptor indicesShape(adaptor.getIndices().getType()); if (indicesShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); } ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = inputShape.getDimSize(0); if (outputShape[2] == ShapedType::kDynamic) outputShape[2] = inputShape.getDimSize(2); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } static LogicalResult ReduceInferReturnTypes( ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl &inferredReturnShapes) { int64_t axisVal = axis.getValue().getSExtValue(); if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) { inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } SmallVector outputShape; operandShape.getDims(outputShape); outputShape[axisVal] = 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } #define COMPATIBLE_RETURN_TYPES(OP) \ bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \ if (l.size() != r.size() || l.size() != 1) \ return false; \ if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \ return false; \ return succeeded(verifyCompatibleShape(l[0], r[0])); \ } #define REDUCE_SHAPE_INFER(OP) \ LogicalResult OP::inferReturnTypeComponents( \ MLIRContext *context, ::std::optional location, \ OP::Adaptor adaptor, \ SmallVectorImpl &inferredReturnShapes) { \ Type inputType = \ llvm::cast(adaptor.getInput().getType()).getElementType(); \ ShapeAdaptor inputShape(adaptor.getInput().getType()); \ const Properties &prop = adaptor.getProperties(); \ return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \ inferredReturnShapes); \ } \ COMPATIBLE_RETURN_TYPES(OP) REDUCE_SHAPE_INFER(tosa::ReduceAllOp) REDUCE_SHAPE_INFER(tosa::ReduceAnyOp) REDUCE_SHAPE_INFER(tosa::ReduceMaxOp) REDUCE_SHAPE_INFER(tosa::ReduceMinOp) REDUCE_SHAPE_INFER(tosa::ReduceProdOp) REDUCE_SHAPE_INFER(tosa::ReduceSumOp) #undef REDUCE_SHAPE_INFER COMPATIBLE_RETURN_TYPES(tosa::ConcatOp) #undef COMPATIBLE_RETURN_TYPES template static LogicalResult verifyReduceOp(T op) { // All TOSA reduce Ops have input, output and axis. TensorType inputType = op.getInput().getType(); TensorType outputType = op.getOutput().getType(); int32_t reduceAxis = op.getAxis(); if (reduceAxis < 0) { op.emitOpError("reduce axis must not be negative"); return failure(); } if (inputType.hasRank()) { int64_t inputRank = inputType.getRank(); // We allow for a special case where the input/output shape has rank 0 and // axis is also 0. 
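  // For example (illustrative): reducing a tensor<2x3x4xf32> along axis 1
  // must produce a tensor<2x1x4xf32>; a rank-0 input with axis 0 is accepted
  // as the special case noted above.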
if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) { op.emitOpError("expect input tensor rank (") << inputRank << ") to be larger than reduce axis (" << reduceAxis << ")"; return failure(); } } if (outputType.hasRank()) { int64_t outputRank = outputType.getRank(); if (inputType.hasRank() && outputRank != inputType.getRank()) { op.emitOpError( "expect output tensor rank to be equal to input tensor rank"); return failure(); } if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) { op.emitOpError("expect output tensor rank (") << outputRank << ") to be larger than reduce axis (" << reduceAxis << ")"; return failure(); } // We can only verify the reduced dimension size to be 1 if this is not // the special case of output rank == 0. if (outputRank != 0) { auto outputShape = outputType.getShape(); if (!outputType.isDynamicDim(reduceAxis) && outputShape[reduceAxis] != 1) { op.emitOpError("expect reduced dimension size to be 1, got ") << outputShape[reduceAxis]; return failure(); } } } return success(); } LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); } static LogicalResult NAryInferReturnTypes( const ValueShapeRange &operands, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outShape; if (resolveBroadcastShape(operands, outShape).failed()) { inferredReturnShapes.push_back(ShapedTypeComponents()); } else { inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); } return success(); } #define NARY_SHAPE_INFER(OP) \ LogicalResult OP::inferReturnTypeComponents( \ MLIRContext *context, ::std::optional location, \ ValueShapeRange operands, DictionaryAttr attributes, \ OpaqueProperties properties, RegionRange regions, \ SmallVectorImpl &inferredReturnShapes) { \ return NAryInferReturnTypes(operands, inferredReturnShapes); \ } NARY_SHAPE_INFER(tosa::AbsOp) NARY_SHAPE_INFER(tosa::AddOp) NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp) NARY_SHAPE_INFER(tosa::BitwiseAndOp) NARY_SHAPE_INFER(tosa::BitwiseOrOp) NARY_SHAPE_INFER(tosa::BitwiseXorOp) NARY_SHAPE_INFER(tosa::BitwiseNotOp) NARY_SHAPE_INFER(tosa::CastOp) NARY_SHAPE_INFER(tosa::CeilOp) NARY_SHAPE_INFER(tosa::ClampOp) NARY_SHAPE_INFER(tosa::ClzOp) NARY_SHAPE_INFER(tosa::CosOp) NARY_SHAPE_INFER(tosa::ExpOp) NARY_SHAPE_INFER(tosa::FloorOp) NARY_SHAPE_INFER(tosa::GreaterEqualOp) NARY_SHAPE_INFER(tosa::GreaterOp) NARY_SHAPE_INFER(tosa::IdentityOp) NARY_SHAPE_INFER(tosa::IntDivOp) NARY_SHAPE_INFER(tosa::LogOp) NARY_SHAPE_INFER(tosa::LogicalAndOp) NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp) NARY_SHAPE_INFER(tosa::LogicalNotOp) NARY_SHAPE_INFER(tosa::LogicalOrOp) NARY_SHAPE_INFER(tosa::LogicalRightShiftOp) NARY_SHAPE_INFER(tosa::LogicalXorOp) NARY_SHAPE_INFER(tosa::MaximumOp) NARY_SHAPE_INFER(tosa::MinimumOp) NARY_SHAPE_INFER(tosa::MulOp) NARY_SHAPE_INFER(tosa::NegateOp) NARY_SHAPE_INFER(tosa::PowOp) NARY_SHAPE_INFER(tosa::ReciprocalOp) NARY_SHAPE_INFER(tosa::RescaleOp) NARY_SHAPE_INFER(tosa::ReverseOp) NARY_SHAPE_INFER(tosa::RsqrtOp) NARY_SHAPE_INFER(tosa::SinOp) NARY_SHAPE_INFER(tosa::SelectOp) NARY_SHAPE_INFER(tosa::SubOp) NARY_SHAPE_INFER(tosa::TanhOp) NARY_SHAPE_INFER(tosa::ErfOp) 
NARY_SHAPE_INFER(tosa::SigmoidOp) #undef PRED_SHAPE_INFER static LogicalResult poolingInferReturnTypes( ShapeAdaptor inputShape, ArrayRef kernel, ArrayRef stride, ArrayRef pad, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(4, ShapedType::kDynamic); // We only know the rank if the input type is unranked. if (!inputShape) { inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } // Batch and number of channels are identical for pooling layer. outputShape[0] = inputShape.getDimSize(0); outputShape[3] = inputShape.getDimSize(3); int64_t height = inputShape.getDimSize(1); int64_t width = inputShape.getDimSize(2); if (!ShapedType::isDynamic(height)) { int64_t padded = height + pad[0] + pad[1] - kernel[0]; outputShape[1] = padded / stride[0] + 1; } if (!ShapedType::isDynamic(width)) { int64_t padded = width + pad[2] + pad[3] - kernel[1]; outputShape[2] = padded / stride[1] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult Conv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, Conv2DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); inputHeight = inputShape.getDimSize(1); inputWidth = inputShape.getDimSize(2); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[3] = weightShape.getDimSize(0); weightHeight = weightShape.getDimSize(1); weightWidth = weightShape.getDimSize(2); } // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getBias().getType()); if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? 
biasShape.getDimSize(0) : outputShape[3]; } llvm::ArrayRef dilation = adaptor.getDilation(); llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef padding = adaptor.getPad(); if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } if (!ShapedType::isDynamic(inputWidth) && !ShapedType::isDynamic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult Conv2DOp::verify() { if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) return failure(); return success(); } LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, Conv3DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(5, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t inputDepth = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; int64_t weightDepth = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); inputDepth = inputShape.getDimSize(1); inputHeight = inputShape.getDimSize(2); inputWidth = inputShape.getDimSize(3); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[4] = weightShape.getDimSize(0); weightDepth = weightShape.getDimSize(1); weightHeight = weightShape.getDimSize(2); weightWidth = weightShape.getDimSize(3); } // Bias shape can describe the output channels. 
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  // Input shape describes input depth/height/width and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter depth/height/width and the output
  // channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int64_t inputSize = inputDepth + pad[0] + pad[1];
    int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + pad[2] + pad[3];
    int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + pad[4] + pad[5];
    int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();
  return success();
}

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride,
                                 prop.pad, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride,
                                 prop.pad, inferredReturnShapes);
}
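// Note on the depthwise convolution below: the weight is laid out as
// [kernel_height, kernel_width, input_channels, channel_multiplier], so the
// number of output channels is input_channels * channel_multiplier. For an
// illustrative 1x16x16x4 input and a 3x3x4x2 weight, outputShape[3] becomes
// 4 * 2 = 8.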
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height, the input channels, and
  // the channel multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the output channels.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult DepthwiseConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();
  return success();
}
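// Note on the transpose convolution below: when a spatial dimension is not
// already fixed by out_shape, it is computed as
//   (inputSize - 1) * stride + out_pad_before + out_pad_after + kernelSize
// e.g. an illustrative input height of 8 with stride 2, zero out_pad and a
// kernel height of 3 yields (8 - 1) * 2 + 0 + 0 + 3 = 17.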
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // outputShape is mutable.
  llvm::SmallVector<int64_t> outputShape =
      convertToMlirShape(adaptor.getOutShape());

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getFilter().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult TransposeConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBody())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // TOSA's while must have a tosa.yield as its terminator. If not found, this
  // tosa.while is invalid.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

// The parse and print methods of IfOp follow the implementation of the SCF
// dialect.
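// The custom assembly implemented by IfOp::parse/IfOp::print below looks
// roughly like the following (illustrative example; values and types are
// made up):
//
//   tosa.cond_if %cond -> (tensor<f32>) {
//     ...
//     tosa.yield %a : tensor<f32>
//   } else {
//     ...
//     tosa.yield %b : tensor<f32>
//   }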
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then'.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  // Create an i1 tensor type for the boolean condition.
  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, i1Type, result.operands))
    return failure();
  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCond();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yield explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getThenBranch(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' region if it exists and has a block.
  auto &elseRegion = getElseBranch();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}

LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");
  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0 and
    // axis is also 0.
    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
      return emitOpError("expect input tensor rank (")
             << inputRank << ") to be larger than reverse axis ("
             << reverseAxis << ")";
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank())
      return emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
      return emitOpError("expect output tensor rank (")
             << outputRank << ") to be larger than reverse axis ("
             << reverseAxis << ")";
  }
  return success();
}

// The parse and print methods of WhileOp follow the implementation of the SCF
// dialect.
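// The custom assembly implemented by WhileOp::parse/WhileOp::print below looks
// roughly like the following (illustrative example; values and types are
// made up):
//
//   tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ...   // condition region, terminated by a tosa.yield of a boolean
//   } do {
//     ...   // body region, yields the values for the next iteration
//   }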
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *cond = result.addRegion();
  Region *body = result.addRegion();

  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}

static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
                          " ");
  parser << " : ";
  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
  parser << ' ';
  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
  parser << " do ";
  parser.printRegion(getBody());
  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}

//===----------------------------------------------------------------------===//
// TOSA Shape and Shape Operators Helper functions.
//===----------------------------------------------------------------------===//

bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}

LogicalResult
mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
                              int rank) {
  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
  return success();
}

LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
  for (auto v : op->getOperands()) {
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp || !definingOp->hasTrait<mlir::OpTrait::ConstantLike>()) {
        return op->emitOpError("shape operand is not compile time resolvable");
      }
    }
  }
  return success();
}

LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
  for (auto type : op->getOperandTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");
    }
  }
  return success();
}

LogicalResult
OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
      failed(verifyTosaShapeOperator(op)))
    return failure();

  // Delegate function that returns the rank of a shape type.
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };
  auto operandTypes = op->getOperandTypes();
  auto resultTypes = op->getResultTypes();

  auto rank = getRank(*op->getOperandTypes().begin());
  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  return success();
}
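// For example, an operator that mixes a !tosa.shape<2> operand with a
// !tosa.shape<3> operand, or that produces a result whose rank differs from
// its operands, is rejected by the rank check above.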
//===----------------------------------------------------------------------===//
// TOSA Shape Operators verify functions.
//===----------------------------------------------------------------------===//

LogicalResult tosa::ConstShapeOp::verify() {
  // Check that the number of elements in the value attribute equals the rank
  // of the result shape type.
  auto count = getValue().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (!(count == rank || (count == 1 && rank == 0))) {
    return emitOpError("expect number of elements in attribute value (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Type Definitions.
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"