// Source: llvm-project/mlir/test/Dialect/Tosa/ops.mlir (revision 956c0707d9098499a2682297b71f46b0a562eed9)
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s


// -----
// CHECK-LABEL: argmax
// argmax along axis 1: the 19-wide axis is reduced away, leaving one i32 index per row.
func.func @test_argmax(%input: tensor<14x19xf32>) -> tensor<14xi32> {
  %indices = tosa.argmax %input {axis = 1 : i32} : (tensor<14x19xf32>) -> tensor<14xi32>
  return %indices : tensor<14xi32>
}

// -----
// CHECK-LABEL: avg_pool2d_f32
// The avg_pool2d cases in this group share one geometry: 2x2 kernel, unit
// stride and pad [0, 1, 0, 1], which keeps the 7x7 spatial extent. They vary
// only the element type and the accumulator type (acc_type).
func.func @test_avg_pool2d_f32(%input: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
  %pooled = tosa.avg_pool2d %input {kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, acc_type = f32} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
  return %pooled : tensor<1x7x7x9xf32>
}

// -----
// CHECK-LABEL: avg_pool2d_f16
// f16 input with an f16 accumulator.
func.func @test_avg_pool2d_f16(%input: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
  %pooled = tosa.avg_pool2d %input {kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, acc_type = f16} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
  return %pooled : tensor<1x7x7x9xf16>
}

// -----
// CHECK-LABEL: avg_pool2d_f16_accumf32
// f16 input with a wider f32 accumulator.
func.func @test_avg_pool2d_f16_accumf32(%input: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
  %pooled = tosa.avg_pool2d %input {kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, acc_type = f32} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
  return %pooled : tensor<1x7x7x9xf16>
}

// -----
// CHECK-LABEL: avg_pool2d_i8
// i8 input with an i32 accumulator.
func.func @test_avg_pool2d_i8(%input: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
  %pooled = tosa.avg_pool2d %input {kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, acc_type = i32} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
  return %pooled : tensor<1x7x7x9xi8>
}

// -----
// CHECK-LABEL: avg_pool2d_i16
// i16 input with an i32 accumulator.
func.func @test_avg_pool2d_i16(%input: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> {
  %pooled = tosa.avg_pool2d %input {kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, acc_type = i32} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16>
  return %pooled : tensor<1x7x7x9xi16>
}

// -----
// CHECK-LABEL: avg_pool2d_q8
// Quantized i8 (scale 0.01) input with an i32 accumulator.
func.func @test_avg_pool2d_q8(%input: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
  %pooled = tosa.avg_pool2d %input {kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>, pad = array<i64: 0, 1, 0, 1>, acc_type = i32} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
  return %pooled : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
}

// -----
// CHECK-LABEL: conv2d
// 1x1 conv2d mapping 4 input channels to 8 output channels, with local_bound set.
func.func @test_conv2d(%input: tensor<1x4x4x4xf32>, %weight: tensor<8x1x1x4xf32>, %bias: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
  %conv = tosa.conv2d %input, %weight, %bias {acc_type = f32, local_bound = true, dilation = array<i64: 1, 1>, stride = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
  return %conv : tensor<1x4x4x8xf32>
}

// -----
// CHECK-LABEL: conv2d_q8xi4
// i8 activations convolved with constant i4 weights into an i32 accumulator,
// then rescaled per channel back to i8. Both ops are written in generic form.
func.func @test_conv2d_q8xi4(%input: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {
  %weight = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4>
  %bias = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32>
  %acc = "tosa.conv2d"(%input, %weight, %bias) {acc_type = i32, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, dilation = array<i64: 1, 1>, stride = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
  %rescaled = "tosa.rescale"(%acc) {input_zp = 0 : i32, output_zp = 27 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, shift = array<i8: 37, 36, 37>, scale32 = true, double_round = true, per_channel = true} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
  return %rescaled : tensor<1x1x1x3xi8>
}

// -----
// CHECK-LABEL: conv3d
// 1x1x1 conv3d mapping 17 input channels to 34 output channels.
func.func @test_conv3d(%input: tensor<1x4x8x21x17xf32>, %weight: tensor<34x1x1x1x17xf32>, %bias: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
  %conv = tosa.conv3d %input, %weight, %bias {acc_type = f32, dilation = array<i64: 1, 1, 1>, stride = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
  return %conv : tensor<1x4x8x21x34xf32>
}

// -----
// CHECK-LABEL: conv3d_with_local_bound
// Same conv3d with local_bound = true.
func.func @test_conv3d_with_local_bound(%input: tensor<1x4x8x21x17xf32>, %weight: tensor<34x1x1x1x17xf32>, %bias: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
  %conv = tosa.conv3d %input, %weight, %bias {acc_type = f32, local_bound = true, dilation = array<i64: 1, 1, 1>, stride = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
  return %conv : tensor<1x4x8x21x34xf32>
}

// -----
// CHECK-LABEL: depthwise_conv2d
// 1x1 depthwise conv2d: 4 input channels with a channel multiplier of 2
// (weight 1x1x4x2) produce 8 output channels.
func.func @test_depthwise_conv2d(%input: tensor<1x4x4x4xf32>, %weight: tensor<1x1x4x2xf32>, %bias: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
  %conv = tosa.depthwise_conv2d %input, %weight, %bias {acc_type = f32, dilation = array<i64: 1, 1>, stride = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
  return %conv : tensor<1x4x4x8xf32>
}

// -----
// CHECK-LABEL: depthwise_conv2d_with_local_bound
// Same depthwise conv2d with local_bound = true.
func.func @test_depthwise_conv2d_with_local_bound(%input: tensor<1x4x4x4xf32>, %weight: tensor<1x1x4x2xf32>, %bias: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
  %conv = tosa.depthwise_conv2d %input, %weight, %bias {acc_type = f32, local_bound = true, dilation = array<i64: 1, 1>, stride = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
  return %conv : tensor<1x4x4x8xf32>
}

// -----
// CHECK-LABEL: fft2d
// Forward (inverse = false) fft2d over separate real/imaginary input planes;
// it yields two results of the same 1x4x8 shape.
func.func @test_fft2d(%real: tensor<1x4x8xf32>, %imag: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
  %out_real, %out_imag = tosa.fft2d %real, %imag {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
  return %out_real, %out_imag : tensor<1x4x8xf32>, tensor<1x4x8xf32>
}

// -----
// CHECK-LABEL: fft2d_with_local_bound
// Same fft2d with local_bound = true.
func.func @test_fft2d_with_local_bound(%real: tensor<1x4x8xf32>, %imag: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
  %out_real, %out_imag = tosa.fft2d %real, %imag {local_bound = true, inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
  return %out_real, %out_imag : tensor<1x4x8xf32>, tensor<1x4x8xf32>
}

// -----
// CHECK-LABEL: fully_connected
// fully_connected: an [N, IC] input times a weight plus an [OC] bias gives
// [N, OC]. Per the TOSA spec the weight is laid out output-channel first,
// [OC, IC] = 28x19, which agrees with the 28-wide bias and 14x28 result.
func.func @test_fully_connected(%input: tensor<14x19xf32>, %weight: tensor<28x19xf32>, %bias: tensor<28xf32>) -> tensor<14x28xf32> {
  %0 = tosa.fully_connected %input, %weight, %bias : (tensor<14x19xf32>, tensor<28x19xf32>, tensor<28xf32>) -> tensor<14x28xf32>
  return %0 : tensor<14x28xf32>
}

// -----
// CHECK-LABEL: test_matmul
// Batched matmul: 1x14x19 times 1x19x28 gives 1x14x28.
func.func @test_matmul(%lhs: tensor<1x14x19xf32>, %rhs: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
  %0 = tosa.matmul %lhs, %rhs : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
  return %0 : tensor<1x14x28xf32>
}

// -----
// CHECK-LABEL: max_pool2d_f32
// The max_pool2d cases use a 1x1 kernel, unit stride and no padding, so the
// shape is unchanged; only the element type varies.
func.func @test_max_pool2d_f32(%input: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
  %0 = tosa.max_pool2d %input {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
  return %0 : tensor<1x32x32x8xf32>
}

// -----
// CHECK-LABEL: max_pool2d_bf16
// bf16 variant of the 1x1 max_pool2d.
func.func @test_max_pool2d_bf16(%input: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
  %0 = tosa.max_pool2d %input {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16>
  return %0 : tensor<1x32x32x8xbf16>
}

// -----
// CHECK-LABEL: max_pool2d_f16
// f16 variant of the 1x1 max_pool2d.
func.func @test_max_pool2d_f16(%input: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
  %0 = tosa.max_pool2d %input {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16>
  return %0 : tensor<1x32x32x8xf16>
}

// -----
// CHECK-LABEL: rfft2d
// Real-input 2-D FFT: the innermost dimension 16 becomes 16/2 + 1 = 9 in both
// the real and imaginary results.
func.func @test_rfft2d(%input: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
  %0, %1 = tosa.rfft2d %input : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
}

// -----
// CHECK-LABEL: rfft2d_with_local_bound
// Same rfft2d with local_bound = true.
func.func @test_rfft2d_with_local_bound(%input: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
  %0, %1 = tosa.rfft2d %input {local_bound = true} : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
}

// -----
// CHECK-LABEL: transpose_conv2d
// 1x1 transpose_conv2d mapping 8 input channels to 16 output channels; the
// out_shape attribute matches the result type.
func.func @test_transpose_conv2d(%input: tensor<1x32x32x8xf32>, %weight: tensor<16x1x1x8xf32>, %bias: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
  %0 = tosa.transpose_conv2d %input, %weight, %bias {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
  return %0 : tensor<1x32x32x16xf32>
}

// -----
// CHECK-LABEL: transpose_conv2d_with_local_bound
// Same transpose_conv2d with local_bound set. The attribute uses true, like the
// other *_with_local_bound tests. false is presumably the attribute's default
// and could be elided when printed, so the round trip would never exercise it.
func.func @test_transpose_conv2d_with_local_bound(%input: tensor<1x32x32x8xf32>, %weight: tensor<16x1x1x8xf32>, %bias: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
  %0 = tosa.transpose_conv2d %input, %weight, %bias {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
  return %0 : tensor<1x32x32x16xf32>
}

// -----
// CHECK-LABEL: clamp
// The clamp cases in this group clamp to [0, 1], spelling out both the
// floating-point bounds (min_fp/max_fp) and the integer bounds (min_int/max_int).
func.func @test_clamp(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %clamped = tosa.clamp %input {max_fp = 1.0 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %clamped : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: clamp_propagate
// Same clamp with an explicit nan_mode of "PROPAGATE".
func.func @test_clamp_propagate(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %clamped = tosa.clamp %input {nan_mode = "PROPAGATE", max_fp = 1.0 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %clamped : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: clamp_ignore
// Same clamp with an explicit nan_mode of "IGNORE".
func.func @test_clamp_ignore(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %clamped = tosa.clamp %input {nan_mode = "IGNORE", max_fp = 1.0 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %clamped : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: clamp_f16
// f16 input with f16-typed floating-point bounds.
func.func @test_clamp_f16(%input: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
  %clamped = tosa.clamp %input {max_fp = 1.0 : f16, max_int = 1 : i64, min_fp = 0.0 : f16, min_int = 0 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16>
  return %clamped : tensor<13x21x3xf16>
}

// -----
// CHECK-LABEL: clamp_bf16
// bf16 input with bf16-typed floating-point bounds.
func.func @test_clamp_bf16(%input: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
  %clamped = tosa.clamp %input {max_fp = 1.0 : bf16, max_int = 1 : i64, min_fp = 0.0 : bf16, min_int = 0 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
  return %clamped : tensor<13x21x3xbf16>
}

// -----
// CHECK-LABEL: clamp_quantized
// Quantized i8 input (scale 0.1, zero point -127).
func.func @test_clamp_quantized(%input: tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>> {
  %clamped = tosa.clamp %input {max_fp = 1.0 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = 0 : i64} : (tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
  return %clamped : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
}

// -----
// CHECK-LABEL: sigmoid
// Elementwise sigmoid on a 13x21x3 f32 tensor.
func.func @test_sigmoid(%x: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %y = tosa.sigmoid %x : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %y : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: tanh
// Elementwise tanh on a 13x21x3 f32 tensor.
func.func @test_tanh(%x: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %y = tosa.tanh %x : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %y : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: erf
// Elementwise erf on a 13x21x3 f32 tensor.
func.func @test_erf(%x: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %y = tosa.erf %x : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %y : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: add
// Broadcasting add: the first operand's size-1 last dimension expands to 3.
func.func @test_add(%lhs: tensor<13x21x1xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %sum = tosa.add %lhs, %rhs : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %sum : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: arithmetic_right_shift
// arithmetic_right_shift is defined by TOSA for integer element types only, so
// it is exercised on i32 operands. The first operand is broadcast along its
// size-1 last dimension.
func.func @test_arithmetic_right_shift(%lhs: tensor<13x21x1xi32>, %rhs: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
  %0 = tosa.arithmetic_right_shift %lhs, %rhs {round = false} : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
  return %0 : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: bitwise_and
// Broadcasting i32 bitwise_and (second operand broadcast along the last dimension).
func.func @test_bitwise_and(%lhs: tensor<13x21x3xi32>, %rhs: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
  %0 = tosa.bitwise_and %lhs, %rhs : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
  return %0 : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: bitwise_or
// Broadcasting i32 bitwise_or (second operand broadcast along the middle dimension).
func.func @test_bitwise_or(%lhs: tensor<13x21x3xi32>, %rhs: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
  %0 = tosa.bitwise_or %lhs, %rhs : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
  return %0 : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: bitwise_xor
// Broadcasting i32 bitwise_xor (first operand broadcast along the last dimension).
func.func @test_bitwise_xor(%lhs: tensor<13x21x1xi32>, %rhs: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
  %0 = tosa.bitwise_xor %lhs, %rhs : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
  return %0 : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: int_div
// Broadcasting i32 integer division (first operand broadcast along the last dimension).
func.func @test_int_div(%lhs: tensor<13x21x1xi32>, %rhs: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
  %0 = tosa.int_div %lhs, %rhs : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
  return %0 : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: logical_and
// Broadcasting i1 logical_and (second operand broadcast along the last dimension).
func.func @test_logical_and(%a: tensor<13x21x3xi1>, %b: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> {
  %out = tosa.logical_and %a, %b : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
  return %out : tensor<13x21x3xi1>
}

// -----
// CHECK-LABEL: logical_left_shift
// Broadcasting i32 logical_left_shift.
func.func @test_logical_left_shift(%a: tensor<13x21x3xi32>, %b: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
  %out = tosa.logical_left_shift %a, %b : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
  return %out : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: logical_right_shift
// Broadcasting i32 logical_right_shift.
func.func @test_logical_right_shift(%a: tensor<13x21x3xi32>, %b: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
  %out = tosa.logical_right_shift %a, %b : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
  return %out : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: logical_or
// Broadcasting i1 logical_or (first operand broadcast along the middle dimension).
func.func @test_logical_or(%a: tensor<13x1x3xi1>, %b: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
  %out = tosa.logical_or %a, %b : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
  return %out : tensor<13x21x3xi1>
}

// -----
// CHECK-LABEL: logical_xor
// Broadcasting i1 logical_xor (first operand broadcast along the middle dimension).
func.func @test_logical_xor(%a: tensor<13x1x3xi1>, %b: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
  %out = tosa.logical_xor %a, %b : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
  return %out : tensor<13x21x3xi1>
}

// -----
// CHECK-LABEL: maximum
// Broadcasting f32 maximum (second operand broadcast along the last dimension).
func.func @test_max(%lhs: tensor<13x21x3xf32>, %rhs: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
  %0 = tosa.maximum %lhs, %rhs : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: minimum
// Broadcasting f32 minimum (second operand broadcast along the first dimension).
func.func @test_min(%lhs: tensor<13x21x3xf32>, %rhs: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> {
  %0 = tosa.minimum %lhs, %rhs : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: mul
// Floating-point mul written without a shift operand.
func.func @test_mul(%lhs: tensor<13x21x3xf32>, %rhs: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
  %0 = tosa.mul %lhs, %rhs : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: i32_mul
// i32 mul with an explicit constant shift operand of 0.
func.func @test_i32_mul(%lhs: tensor<13x21x3xi32>, %rhs: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %lhs, %rhs, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
  return %0 : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: mul_relaxed_result_type
// i16 operands with an i16 (not widened) result type. The label is unique so
// FileCheck anchors on this function instead of reusing the generic "mul".
func.func @test_mul_relaxed_result_type(%lhs: tensor<13x21x3xi16>, %rhs: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %lhs, %rhs, %shift : (tensor<13x21x3xi16>, tensor<13x1x3xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
  return %0 : tensor<13x21x3xi16>
}

// -----
// CHECK-LABEL: pow
// Broadcasting f32 pow (exponent broadcast along the last dimension).
func.func @test_pow(%base: tensor<13x21x3xf32>, %exponent: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
  %0 = tosa.pow %base, %exponent : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: sub
// Broadcasting f32 sub (first operand broadcast along the first dimension).
func.func @test_sub(%lhs: tensor<1x21x3xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %0 = tosa.sub %lhs, %rhs : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: table
// Table lookup on a 64-element input through a 513-entry i16 quantized table.
// It is named @test_table like every other test in this file (it was @main).
// The quantized type is spelled one way throughout; scale 1.0 with zero point
// 0 prints as 1.000000e+00.
// NOTE(review): the TOSA spec pairs a 513-entry i16 table with i16 input and
// i32 output; this test uses i32 input and an i16 output, so confirm the
// dialect intentionally accepts this combination.
func.func @test_table(%input: tensor<64xi32>, %table: tensor<513x!quant.uniform<i16:f32, 1.000000e+00>>) -> tensor<64x!quant.uniform<i16:f32, 1.000000e+00>> {
  %0 = tosa.table %input, %table : (tensor<64xi32>, tensor<513x!quant.uniform<i16:f32, 1.000000e+00>>) -> tensor<64x!quant.uniform<i16:f32, 1.000000e+00>>
  return %0 : tensor<64x!quant.uniform<i16:f32, 1.000000e+00>>
}

// -----
// CHECK-LABEL: abs
// Elementwise abs on f32.
func.func @test_abs(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.abs %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: bitwise_not
// Elementwise bitwise_not on i32.
func.func @test_bitwise_not(%operand: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> {
  %res = tosa.bitwise_not %operand : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32>
  return %res : tensor<13x21x1xi32>
}

// -----
// CHECK-LABEL: ceil
// Elementwise ceil on f32.
func.func @test_ceil(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.ceil %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: clz
// Elementwise count-leading-zeros on i32.
func.func @test_clz(%operand: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
  %res = tosa.clz %operand : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
  return %res : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: cos
// Elementwise cos on f32.
func.func @test_cos(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.cos %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: exp
// Elementwise exp on f32.
func.func @test_exp(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.exp %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: floor
// Elementwise floor on f32.
func.func @test_floor(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.floor %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: log
// Elementwise log on f32.
func.func @test_log(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.log %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: logical_not
// Elementwise logical_not on i1.
func.func @test_logical_not(%operand: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
  %res = tosa.logical_not %operand : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1>
  return %res : tensor<1x21x3xi1>
}

// -----
// CHECK-LABEL: negate
// Elementwise negate on f32.
func.func @test_negate(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.negate %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: reciprocal
// Elementwise reciprocal on f32.
func.func @test_reciprocal(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.reciprocal %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: rsqrt
// Elementwise reciprocal square root on f32.
func.func @test_rsqrt(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.rsqrt %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: sin
// Elementwise sin on f32.
func.func @test_sin(%operand: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.sin %operand : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: select
// The 1x1x1 predicate is broadcast over both 13x21x3 value operands.
func.func @test_select(%pred: tensor<1x1x1xi1>, %on_true: tensor<13x21x3xf32>, %on_false: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %res = tosa.select %pred, %on_true, %on_false : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %res : tensor<13x21x3xf32>
}


// -----
// CHECK-LABEL: equal
// Broadcasting f32 comparison producing an i1 mask.
func.func @test_equal(%lhs: tensor<13x21x3xf32>, %rhs: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> {
  %mask = tosa.equal %lhs, %rhs : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1>
  return %mask : tensor<13x21x3xi1>
}

// -----
// CHECK-LABEL: greater
// Broadcasting f32 greater-than producing an i1 mask.
func.func @test_greater(%lhs: tensor<13x21x1xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
  %mask = tosa.greater %lhs, %rhs : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
  return %mask : tensor<13x21x3xi1>
}

// -----
// CHECK-LABEL: greater_equal
// Broadcasting f32 greater-or-equal producing an i1 mask.
func.func @test_greater_equal(%lhs: tensor<13x1x3xf32>, %rhs: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
  %mask = tosa.greater_equal %lhs, %rhs : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
  return %mask : tensor<13x21x3xi1>
}

// -----
// CHECK-LABEL: reduce_all
// Each reduction in this group reduces axis 0 while keeping it as size 1
// (13x21x3 -> 1x21x3); a reshape then drops that axis to give 21x3.
func.func @test_reduce_all(%input: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
  %reduced = tosa.reduce_all %input {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
  %squeezed = tosa.reshape %reduced {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
  return %squeezed : tensor<21x3xi1>
}

// -----
// CHECK-LABEL: reduce_any
// reduce_any over axis 0 on i1.
func.func @test_reduce_any(%input: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
  %reduced = tosa.reduce_any %input {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
  %squeezed = tosa.reshape %reduced {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
  return %squeezed : tensor<21x3xi1>
}

// -----
// CHECK-LABEL: reduce_max
// reduce_max over axis 0 on f32.
func.func @test_reduce_max(%input: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
  %reduced = tosa.reduce_max %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
  %squeezed = tosa.reshape %reduced {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
  return %squeezed : tensor<21x3xf32>
}

// -----
// CHECK-LABEL: reduce_min
// reduce_min over axis 0 on f32.
func.func @test_reduce_min(%input: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
  %reduced = tosa.reduce_min %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
  %squeezed = tosa.reshape %reduced {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
  return %squeezed : tensor<21x3xf32>
}

// -----
// CHECK-LABEL: reduce_product
// reduce_prod over axis 0 on f32.
func.func @test_reduce_product(%input: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
  %reduced = tosa.reduce_prod %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
  %squeezed = tosa.reshape %reduced {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
  return %squeezed : tensor<21x3xf32>
}

// -----
// CHECK-LABEL: reduce_sum
// reduce_sum over axis 0 on f32.
func.func @test_reduce_sum(%input: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
  %reduced = tosa.reduce_sum %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
  %squeezed = tosa.reshape %reduced {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
  return %squeezed : tensor<21x3xf32>
}

// -----
// CHECK-LABEL: concat
// Concatenating two 13x21x3 tensors along axis 0 gives 26x21x3.
func.func @test_concat(%first: tensor<13x21x3xf32>, %second: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
  %joined = tosa.concat %first, %second {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
  return %joined : tensor<26x21x3xf32>
}

// -----
// CHECK-LABEL: pad
// pad with an all-zero 6-entry padding shape, so the shape is unchanged.
func.func @test_pad(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
  %padded = tosa.pad %input, %padding : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
  return %padded : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: pad_explicit_value
// Same pad, with the optional pad-value operand supplied as a scalar constant.
func.func @test_pad_explicit_value(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %pad_value = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
  %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
  %padded = tosa.pad %input, %padding, %pad_value : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x3xf32>
  return %padded : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: reshape
// Flattening reshape: 13 * 21 * 3 = 819 elements.
func.func @test_reshape(%input: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
  %flat = tosa.reshape %input {new_shape = array<i64: 1, 819>} : (tensor<13x21x3xf32>) -> tensor<1x819xf32>
  return %flat : tensor<1x819xf32>
}

// -----
// CHECK-LABEL: reverse
// Reverse along axis 0.
func.func @test_reverse(%input: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %reversed = tosa.reverse %input {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
  return %reversed : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: slice
// Per the TOSA spec, slice's operands are (input, start, size). Taking a
// 4x11x1 window starting at [6, 8, 0] stays inside the 13x21x3 input and
// yields the declared 4x11x1 result.
func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
  %start = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %size = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %0 = tosa.slice %arg0, %start, %size : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32>
  return %0 : tensor<4x11x1xf32>
}

// -----
// CHECK-LABEL: slice_size
// A size of -1 on dimension 0 covers the remainder of that dimension from
// start 6, i.e. 13 - 6 = 7, which gives the declared 7x11x1 result.
func.func @test_slice_size(%arg0: tensor<13x21x3xf32>) -> tensor<7x11x1xf32> {
  %start = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %size = tosa.const_shape {value = dense<[-1, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %0 = tosa.slice %arg0, %start, %size : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x11x1xf32>
  return %0 : tensor<7x11x1xf32>
}

// -----
// CHECK-LABEL: tile
// Tiling by multiples [3, 1, 2] turns 13x21x3 into 39x21x6.
func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
  %multiples = tosa.const_shape {value = dense<[3, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %0 = tosa.tile %arg0, %multiples : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32>
  return %0 : tensor<39x21x6xf32>
}

// -----
// CHECK-LABEL: transpose
// Permutation [2, 0, 1] maps 13x21x3 to 3x13x21.
func.func @test_transpose(%input: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
  %perms = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %permuted = tosa.transpose %input, %perms : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
  return %permuted : tensor<3x13x21xf32>
}

// -----
// CHECK-LABEL: transpose_dynamic_dim
// Same permutation with a dynamic middle input dimension, which lands last.
func.func @test_transpose_dynamic_dim(%input: tensor<13x?x3xf32>) -> tensor<3x13x?xf32> {
  %perms = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %permuted = tosa.transpose %input, %perms : (tensor<13x?x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
  return %permuted : tensor<3x13x?xf32>
}

// -----
// CHECK-LABEL: transpose_half_dynamic_dim
// Fully static input whose result type leaves the last dimension dynamic.
func.func @test_transpose_half_dynamic_dim(%input: tensor<13x3x3xf32>) -> tensor<3x13x?xf32> {
  %perms = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %permuted = tosa.transpose %input, %perms : (tensor<13x3x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
  return %permuted : tensor<3x13x?xf32>
}

// -----
// CHECK-LABEL: gather
// Gather 26 rows per batch from a 13x21x3 values tensor using 13x26 i32 indices.
func.func @test_gather(%values: tensor<13x21x3xf32>, %indices: tensor<13x26xi32>) -> tensor<13x26x3xf32> {
  %gathered = tosa.gather %values, %indices : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x3xf32>
  return %gathered : tensor<13x26x3xf32>
}

// -----
// CHECK-LABEL: scatter
// Scatter 13x26x3 updates into a 13x21x3 tensor at 13x26 i32 indices.
func.func @test_scatter(%values_in: tensor<13x21x3xf32>, %indices: tensor<13x26xi32>, %updates: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
  %scattered = tosa.scatter %values_in, %indices, %updates : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
  return %scattered : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: resize
// Bilinear resize of the 32x32 spatial extent to 64x64.
func.func @test_resize(%input: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
  %resized = tosa.resize %input {mode = "BILINEAR", scale = array<i64: 4, 2, 4, 2>, offset = array<i64: -1, -1>, border = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32>
  return %resized : tensor<1x64x64x8xf32>
}

// -----
// CHECK-LABEL: cast
// Plain i32 -> f32 cast.
func.func @test_cast1(%input: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
  %converted = tosa.cast %input : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
  return %converted : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: cast2
// i32 -> quantized u8 (zero point 128) cast.
func.func @test_cast2(%input: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<u8:f32, 0.078431375324726104:128>> {
  %converted = tosa.cast %input : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<u8:f32, 0.078431375324726104:128>>
  return %converted : tensor<13x21x3x!quant.uniform<u8:f32, 0.078431375324726104:128>>
}

// -----
// CHECK-LABEL: cast3
// i32 -> quantized i16 (zero point 128) cast.
func.func @test_cast3(%input: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>> {
  %converted = tosa.cast %input : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
  return %converted : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
}

// -----
// CHECK-LABEL: rescale
// Per-tensor rescale from quantized u8 (zero point 127) to quantized i8
// (zero point -1) with a single multiplier/shift pair.
func.func @test_rescale(%input: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
  %rescaled = tosa.rescale %input {input_zp = 127 : i32, output_zp = -1 : i32, multiplier = array<i32: 1073741824>, shift = array<i8: 30>, scale32 = true, double_round = false, per_channel = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
  return %rescaled : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}

// -----
// CHECK-LABEL: const
// A 4-element i32 constant; the index argument is never used.
func.func @test_const(%unused : index) -> tensor<4xi32> {
  %values = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
  return %values : tensor<4xi32>
}

// -----
// CHECK-LABEL: identity
// identity forwards its i32 operand unchanged in type.
func.func @test_identity(%input: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
  %same = tosa.identity %input : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
  return %same : tensor<13x21x3xi32>
}

// -----
// CHECK-LABEL: cond_if
// Scalar cond_if: the then-branch yields lhs + rhs, the else-branch lhs - rhs.
// Both regions read the function arguments directly.
func.func @test_cond_if(%lhs: tensor<f32>, %rhs: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
  %result = tosa.cond_if %pred -> (tensor<f32>) {
    %sum = tosa.add %lhs, %rhs : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %sum : tensor<f32>
  } else {
    %diff = tosa.sub %lhs, %rhs : (tensor<f32>, tensor<f32>) -> tensor<f32>
    tosa.yield %diff : tensor<f32>
  }
  return %result : tensor<f32>
}

// -----
// CHECK-LABEL: while_loop
// Loop-carried state is (counter, iteration, buffer), with the two scalars
// starting at 0. The loop continues while iteration < %limit, and each trip
// adds 1 to all three values (the 1 is reshaped to 1xi32 and broadcast over
// the 10-element buffer). The loop results are unused.
func.func @test_while_loop(%buffer_init: tensor<10xi32>, %limit: tensor<i32>) {
  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %final:3 = tosa.while_loop (%counter = %zero, %iter = %zero, %buffer = %buffer_init) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
    %done = tosa.greater_equal %iter, %limit : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %keep_going = tosa.logical_not %done : (tensor<i1>) -> tensor<i1>
    tosa.yield %keep_going : tensor<i1>
  } do {
  ^bb0(%counter: tensor<i32>, %iter: tensor<i32>, %buffer: tensor<10xi32>):
    %one = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
    %next_iter = tosa.add %iter, %one : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %one_1d = tosa.reshape %one {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
    %next_buffer = tosa.add %buffer, %one_1d : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
    %next_counter = tosa.add %counter, %one : (tensor<i32>, tensor<i32>) -> tensor<i32>
    tosa.yield %next_counter, %next_iter, %next_buffer : tensor<i32>, tensor<i32>, tensor<10xi32>
  }
  return
}

// -----
// CHECK-LABEL: custom
// tosa.custom carrying an operator name, a domain name and an empty
// implementation_attrs string.
func.func @test_custom(%input: tensor<10xi32>) -> tensor<10xi32> {
  %out = tosa.custom %input {domain_name = "tosa.mlir_test", implementation_attrs = "", operator_name = "custom_test"} : (tensor<10xi32>) -> (tensor<10xi32>)
  return %out : tensor<10xi32>
}

// -----
// CHECK-LABEL: const_shape
// A rank-4 shape value with every entry equal to 1.
func.func @test_const_shape() -> !tosa.shape<4> {
  %shape = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
  return %shape : !tosa.shape<4>
}

739