xref: /llvm-project/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (revision a58e774fba42e13aa00667d644e96b783fc914b4)
1 //===-- TosaOps.h - TOSA dialect operation definitions ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the TOSA Dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_TOSA_IR_TOSAOPS_H
14 #define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
15 
16 #include "mlir/Bytecode/BytecodeOpInterface.h"
17 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
18 #include "mlir/Dialect/Traits.h"
19 #include "mlir/IR/OpDefinition.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Interfaces/InferTypeOpInterface.h"
23 #include "mlir/Interfaces/LoopLikeInterface.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 #include "mlir/Interfaces/VectorInterfaces.h"
26 
27 //===----------------------------------------------------------------------===//
28 // TOSA dialect and structs includes.
29 //===----------------------------------------------------------------------===//
30 
31 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
32 
33 namespace mlir {
34 class PatternRewriter;
35 
36 namespace tosa {
37 
/// Custom-assembly parse hook operating on a (type attribute, attribute) pair.
/// `typeAttr` and `attr` are taken by non-const reference, so they appear to
/// be the outputs filled in from `parser`.
/// NOTE(review): the accepted textual syntax is defined by the out-of-line
/// implementation, which is not visible in this header — confirm there.
ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
                            Attribute &attr);
/// Printing counterpart of parseTypeOrAttr: writes to `p` for operation `op`
/// given the `type` and `attr` values.
/// NOTE(review): presumably emits the same syntax parseTypeOrAttr accepts —
/// confirm against the implementation.
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                     Attribute attr);
42 
43 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
44 
45 } // namespace tosa
46 
47 namespace OpTrait {
48 namespace tosa {
49 
50 // This trait verifies if the element type amoung operands and result
51 // of multiplication match tosa specification.
52 template <typename ConcreteType>
53 class MulOperandsAndResultElementType
54     : public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
55 public:
56   static LogicalResult verifyTrait(Operation *op) {
57     // Check we have a single result.
58     if (failed(impl::verifyOneResult(op)))
59       return failure();
60     Type resElemType = getElementTypeOrSelf(op->getResult(0));
61 
62     // Check we have lhs and rhs.
63     if (failed(impl::verifyAtLeastNOperands(op, 2)))
64       return failure();
65 
66     Type lhsElemType = getElementTypeOrSelf(op->getOperand(0));
67     Type rhsElemType = getElementTypeOrSelf(op->getOperand(1));
68 
69     // Check that for i32 a shift has been explicitly provided.
70     if (lhsElemType.isInteger(32) && failed(impl::verifyNOperands(op, 3)))
71       return failure();
72 
73     // Verify operands type match (ignoring the shift parameter which will
74     // always be i8).
75     if (lhsElemType != rhsElemType)
76       return op->emitOpError("requires the same element type for all operands");
77 
78     // Though the spec requires the element type of result to be i32, a more
79     // relaxed way is provided at dialect level for easier cooperating with
80     // other dialects.
81     if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
82       auto lhsIntType = cast<IntegerType>(lhsElemType);
83       if (lhsIntType.getWidth() > resIntType.getWidth())
84         return op->emitOpError("invalid data type size for operands or result");
85     } else {
86       // In cases of floating point type or quant types, op requires the same
87       // element type for all operands and result (excluding shift).
88       if (resElemType != lhsElemType)
89         return op->emitOpError(
90             "requires the same element type for all operands and results");
91     }
92 
93     return llvm::success();
94   }
95 };
96 
/// This class indicates that an op is tosa-elementwise (permits broadcasting,
/// unlike Elementwise trait).
///
/// It is a pure marker trait: the class body is empty and defines no
/// `verifyTrait` of its own, so attaching it adds no verification here.
template <typename ConcreteType>
class TosaElementwiseOperator
    : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
102 
/// Non-template verifier backing the TosaResolvableShapeOperands trait. Only
/// declared in this header; the checking logic is defined elsewhere.
LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
/// This class verifies that tosa shape operands are compile time resolvable.
/// The check is delegated to verifyTosaResolvableShapeOperands, so the
/// template itself carries no logic beyond forwarding.
template <typename ConcreteType>
class TosaResolvableShapeOperands
    : public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
public:
  /// Trait verification hook: forwards `op` to
  /// verifyTosaResolvableShapeOperands and returns its result unchanged.
  static LogicalResult verifyTrait(Operation *op) {
    return verifyTosaResolvableShapeOperands(op);
  }
};
113 
/// Non-template verifier backing the TosaShapeOperator trait. Only declared
/// in this header; the checking logic is defined elsewhere.
LogicalResult verifyTosaShapeOperator(Operation *op);
/// This class indicates that op operates on tosa shape types.
/// Verification is delegated to verifyTosaShapeOperator.
template <typename ConcreteType>
class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
public:
  /// Trait verification hook: forwards `op` to verifyTosaShapeOperator and
  /// returns its result unchanged.
  static LogicalResult verifyTrait(Operation *op) {
    return verifyTosaShapeOperator(op);
  }
};
123 
/// Non-template verifier backing the TosaShapeOperatorWithSameRanks trait.
/// Only declared in this header; the checking logic is defined elsewhere.
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
/// This class indicates that op operates on tosa shape types.
/// Verification is delegated to verifyTosaShapeOperatorWithSameRanks.
/// NOTE(review): judging by the name, this presumably also requires the
/// shape operands/results to have matching ranks — confirm against the
/// implementation, which is not visible in this header.
template <typename ConcreteType>
class TosaShapeOperatorWithSameRanks
    : public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
public:
  /// Trait verification hook: forwards `op` to
  /// verifyTosaShapeOperatorWithSameRanks and returns its result unchanged.
  static LogicalResult verifyTrait(Operation *op) {
    return verifyTosaShapeOperatorWithSameRanks(op);
  }
};
134 
135 } // namespace tosa
136 } // namespace OpTrait
137 
138 namespace tosa {
139 
/// Predicate over a type. Presumably returns true when `t` is a TOSA shape
/// type — confirm against the definition, which is not in this header.
bool isa_tosa_shape_type(mlir::Type t);
141 
142 } // namespace tosa
143 
144 } // namespace mlir
145 
146 #define GET_ATTRDEF_CLASSES
147 #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
148 
149 #define GET_TYPEDEF_CLASSES
150 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc"
151 
152 #define GET_OP_CLASSES
153 #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
154 
155 #endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H
156