// Source: mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (llvm-project, revision a58e774fba42e13aa00667d644e96b783fc914b4)
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics -o -| FileCheck %s

// tosa.abs on a rank-0 tensor lowers to a linalg.generic with an empty
// iteration space (no loops) applying math.absf.
// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>

// CHECK-LABEL: @test_abs_scalar
// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
func.func @test_abs_scalar(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<f32>
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins([[ARG0]] : tensor<f32>) outs([[INIT]] : tensor<f32>) {
  // CHECK:   ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
  // CHECK:   [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
  // CHECK:   linalg.yield [[ELEMENT]] : f32
  // CHECK: } -> tensor<f32>
  %0 = tosa.abs %arg0 : (tensor<f32>) -> tensor<f32>

  // CHECK: return [[GENERIC]] : tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Static input with a dynamic result type: the op lowers using the static
// shape and the result is tensor.cast to the dynamic type.
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_abs_1d_cast_static_to_dynamic
// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
func.func @test_abs_1d_cast_static_to_dynamic(%arg0: tensor<5xf32>) -> tensor<?xf32> {
  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<5xf32>
  // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]] : tensor<5xf32>) outs([[EMPTY]] : tensor<5xf32>) {
  // CHECK: ^bb0([[IN0:%.+]]: f32, [[OUT0:%.+]]: f32):
  // CHECK:   [[ABS:%.+]] = math.absf [[IN0]] : f32
  // CHECK:   linalg.yield [[ABS]] : f32
  // CHECK: } -> tensor<5xf32>
  // CHECK: [[CAST_RESULT:%.+]] = tensor.cast [[RESULT]] : tensor<5xf32> to tensor<?xf32>
  %0 = "tosa.abs"(%arg0) : (tensor<5xf32>) -> tensor<?xf32>

  // CHECK: return [[CAST_RESULT]] : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// Dynamic input with a static result type: the dynamic extent is queried with
// tensor.dim, the op lowers on the dynamic shape, and the result is
// tensor.cast to the static type.
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_abs_1d_cast_dynamic_to_static
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
func.func @test_abs_1d_cast_dynamic_to_static(%arg0: tensor<?xf32>) -> tensor<5xf32> {
  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
  // CHECK: %[[DIM_SIZE:.*]] = tensor.dim %[[ARG0]], %[[ZERO]] : tensor<?xf32>
  // CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM_SIZE]]) : tensor<?xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
  // CHECK: ^bb0(%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32):
  // CHECK:   %[[VAL_2:.*]] = math.absf %[[VAL_0]] : f32
  // CHECK:   linalg.yield %[[VAL_2]] : f32
  // CHECK: } -> tensor<?xf32>
  // CHECK: %[[CAST_RESULT:.*]] = tensor.cast %[[RESULT]] : tensor<?xf32> to tensor<5xf32>
  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<5xf32>

  // CHECK: return %[[CAST_RESULT]] : tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// Fully dynamic 1-D abs: the output is sized from the input's dynamic extent.
// Uses the captured map/argument rather than literal `#map`/`%arg0` so the
// test is robust against attribute and SSA renumbering.
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_abs_1d_dynamic
// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
func.func @test_abs_1d_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {

  // CHECK: [[ZERO:%.+]] = arith.constant 0 : index
  // CHECK: [[DIM:%.+]] = tensor.dim [[ARG0]], [[ZERO]] : tensor<?xf32>
  // CHECK: [[EMPTY:%.+]] = tensor.empty([[DIM]]) : tensor<?xf32>
  // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]] : tensor<?xf32>) outs([[EMPTY]] : tensor<?xf32>) {
  // CHECK: ^bb0([[IN0:%.+]]: f32, [[OUT0:%.+]]: f32):
  // CHECK:   [[ABSF:%.+]] = math.absf [[IN0]] : f32
  // CHECK:   linalg.yield [[ABSF]] : f32
  // CHECK: } -> tensor<?xf32>
  %0 = tosa.abs %arg0 : (tensor<?xf32>) -> tensor<?xf32>

  // CHECK: return [[RESULT]] : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// Rank-0 tosa.add lowers to a linalg.generic with an empty iteration space.
// Uses the captured map rather than literal `#map` so the test is robust
// against attribute renumbering.
// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
// CHECK-LABEL: @test_add_0d
// CHECK-SAME: [[ARG0:%[0-9a-zA-Z_]*]]:
// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {

  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<f32>
  // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins([[ARG0]], [[ARG1]] : tensor<f32>, tensor<f32>) outs([[EMPTY]] : tensor<f32>) {
  // CHECK: ^bb0([[IN0:%.+]]: f32, [[IN1:%.+]]: f32, [[OUT0:%.+]]: f32):
  // CHECK:   [[ADDF:%.+]] = arith.addf [[IN0]], [[IN1]] : f32
  // CHECK:   linalg.yield [[ADDF]] : f32
  // CHECK: } -> tensor<f32>
  %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>

  // CHECK: return [[RESULT]] : tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Static 2-D add where the second operand broadcasts along dim 0: the
// broadcast is expressed via the (0, d1) indexing map on that operand.
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, d1)>

// CHECK-LABEL:   func.func @test_add_2d_broadcast(
// CHECK-SAME:                                     %[[ARG0:.*]]: tensor<2x1xf32>,
// CHECK-SAME:                                     %[[ARG1:.*]]: tensor<1x1xf32>) -> tensor<2x1xf32> {
// CHECK:           %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
// CHECK:           %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
// CHECK:           ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK:             %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
// CHECK:             linalg.yield %[[ADD]] : f32
// CHECK:           } -> tensor<2x1xf32>
// CHECK:           return %[[RESULT]] : tensor<2x1xf32>
// CHECK:         }
func.func @test_add_2d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x1xf32> {
  // tosa element-wise operators now require operands of equal ranks
  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1x1xf32>) -> tensor<2x1xf32>
  return %0 : tensor<2x1xf32>
}

// -----

// Both operands fully dynamic: each operand is conditionally broadcast at
// runtime (scf.if on extent == 1) to the max extent before the element-wise add.
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_add_1d_all_dynamic
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_1d_all_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {

  // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
  // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?xf32>
  // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?xf32>
  // CHECK: %[[ARG0_MAX_DIM:.*]] = arith.maxui %[[ARG0_DIM0]], %[[ARG1_DIM0]] : index
  // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
  // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?xf32>
  // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
  // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?xf32>) {
  // CHECK:   %[[VAL_2:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor<?xf32>
  // CHECK:   %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<?xf32>) outs(%[[VAL_2]] : tensor<?xf32>) {
  // CHECK:   ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_4]] : f32
  // CHECK:   } -> tensor<?xf32>
  // CHECK:   scf.yield %[[VAL_3]] : tensor<?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG0]] : tensor<?xf32>
  // CHECK: }
  // CHECK: %[[VAL_6:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?xf32>
  // CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_6]], %[[CONST1]] : index
  // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_7]] -> (tensor<?xf32>) {
  // CHECK:   %[[VAL_8:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor<?xf32>
  // CHECK:   %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG1]] : tensor<?xf32>) outs(%[[VAL_8]] : tensor<?xf32>) {
  // CHECK:   ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_10]] : f32
  // CHECK:   } -> tensor<?xf32>
  // CHECK:   scf.yield %[[VAL_9]] : tensor<?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG1]] : tensor<?xf32>
  // CHECK: }
  // CHECK: %[[VAL_12:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor<?xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0_DIM0_BROADCAST]], %[[ARG0_DIM1_BROADCAST]] : tensor<?xf32>, tensor<?xf32>) outs(%[[VAL_12]] : tensor<?xf32>) {
  // CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32):
  // CHECK:   %[[VAL_16:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
  // CHECK:   linalg.yield %[[VAL_16]] : f32
  // CHECK: } -> tensor<?xf32>
  %0 = tosa.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

  // CHECK: return %[[RESULT]] : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// Dynamic operand broadcast against a static one: the conditional broadcast
// builds a static tensor<5xf32> and casts it back to the dynamic type.
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_add_1d_broadcast_dynamic_to_static
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_1d_broadcast_dynamic_to_static(%arg0: tensor<5xf32>, %arg1: tensor<?xf32>) -> tensor<5xf32> {

  // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
  // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
  // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?xf32>
  // CHECK: %[[VAL_0:.*]] = arith.cmpi eq, %[[ARG1_DIM0]], %[[CONST1]] : index
  // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_0]] -> (tensor<?xf32>) {
  // CHECK:   %[[VAL_1:.*]] = tensor.empty() : tensor<5xf32>
  // CHECK:   %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG1]] : tensor<?xf32>) outs(%[[VAL_1]] : tensor<5xf32>) {
  // CHECK:   ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_3]] : f32
  // CHECK:   } -> tensor<5xf32>
  // CHECK:   %[[VAL_5:.*]] = tensor.cast %[[VAL_2]] : tensor<5xf32> to tensor<?xf32>
  // CHECK:   scf.yield %[[VAL_5]] : tensor<?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG1]] : tensor<?xf32>
  // CHECK: }
  // CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1_DIM0_BROADCAST]] : tensor<5xf32>, tensor<?xf32>) outs(%[[VAL_6]] : tensor<5xf32>) {
  // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32):
  // CHECK:   %[[VAL_10:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
  // CHECK:   linalg.yield %[[VAL_10]] : f32
  // CHECK: } -> tensor<5xf32>
  %0 = tosa.add %arg0, %arg1 : (tensor<5xf32>, tensor<?xf32>) -> tensor<5xf32>

  // CHECK: return %[[RESULT]] : tensor<5xf32>
  return %0 : tensor<5xf32>
}

// -----

// Static size-1 operand broadcast against a dynamic one: no runtime check is
// needed; the (0) indexing map performs the broadcast directly.
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_add_1d_broadcast_static_to_dynamic
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_1d_broadcast_static_to_dynamic(%arg0: tensor<1xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {

  // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
  // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?xf32>
  // CHECK: %[[VAL_0:.*]] = tensor.empty(%[[ARG1_DIM0]]) : tensor<?xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<?xf32>) outs(%[[VAL_0]] : tensor<?xf32>) {
  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
  // CHECK:   linalg.yield %[[VAL_4]] : f32
  // CHECK: } -> tensor<?xf32>
  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<?xf32>) -> tensor<?xf32>

  // CHECK: return %[[RESULT]] : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// Static size-1 operand broadcast to a static size-3 result via the (0)
// indexing map; everything resolves at compile time.
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_add_1d_broadcast_static_to_static
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {

  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<3xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<3xf32>) {
  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
  // CHECK:   linalg.yield %[[VAL_4]] : f32
  // CHECK: } -> tensor<3xf32>
  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<3xf32>) -> tensor<3xf32>

  // CHECK: return %[[RESULT]] : tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

// Matching 1x shapes: no broadcast map is emitted; identity maps throughout.
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_add_1d_matching_no_broadcast
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_1d_matching_no_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {

  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_0]] : tensor<1xf32>) {
  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
  // CHECK:   linalg.yield %[[VAL_4]] : f32
  // CHECK: } -> tensor<1xf32>
  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

  // CHECK: return %[[RESULT]] : tensor<1xf32>
  return %0 : tensor<1xf32>
}

// -----

// Matching static shapes: plain element-wise add with identity maps.
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_add_1d_matching_static
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_1d_matching_static(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {

  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<3xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<3xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<3xf32>) {
  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
  // CHECK:   linalg.yield %[[VAL_4]] : f32
  // CHECK: } -> tensor<3xf32>
  %0 = tosa.add %arg0, %arg1 : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>

  // CHECK: return %[[RESULT]] : tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

// Fully dynamic 2-D add: each operand gets a runtime size-1 check and
// conditional broadcast per dimension (four scf.if regions) before the final
// element-wise add on the max extents.
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (0, d1)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, 0)>
// CHECK-LABEL: @test_add_2d_all_dynamic
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
  // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
  // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
  // CHECK: %[[MAX_DIM0:.*]] = arith.maxui %[[ARG0_DIM0]], %[[ARG1_DIM0]] : index
  // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
  // CHECK: %[[ARG0_DIM1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<?x?xf32>
  // CHECK: %[[ARG1_DIM1:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<?x?xf32>
  // CHECK: %[[MAX_DIM1:.*]] = arith.maxui %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index

  // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
  // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
  // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?x?xf32>) {
  // CHECK:   %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
  // CHECK:   %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
  // CHECK:   %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_2]]) : tensor<?x?xf32>
  // CHECK:   %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[VAL_3]] : tensor<?x?xf32>) {
  // CHECK:   ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_5]] : f32
  // CHECK:   } -> tensor<?x?xf32>
  // CHECK:   scf.yield %[[VAL_4]] : tensor<?x?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG0]] : tensor<?x?xf32>
  // CHECK: }

  // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
  // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
  // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<?x?xf32>) {
  // CHECK:   %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
  // CHECK:   %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
  // CHECK:   %[[VAL_10:.*]] = tensor.empty(%[[VAL_9]], %[[MAX_DIM1]]) : tensor<?x?xf32>
  // CHECK:   %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_10]] : tensor<?x?xf32>) {
  // CHECK:   ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_12]] : f32
  // CHECK:   } -> tensor<?x?xf32>
  // CHECK:   scf.yield %[[VAL_11]] : tensor<?x?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>
  // CHECK: }

  // CHECK: %[[VAL_14:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
  // CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_14]], %[[CONST1]] : index
  // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_15]] -> (tensor<?x?xf32>) {
  // CHECK:   %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
  // CHECK:   %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
  // CHECK:   %[[VAL_17:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_16]]) : tensor<?x?xf32>
  // CHECK:   %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[VAL_17]] : tensor<?x?xf32>) {
  // CHECK:   ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_19]] : f32
  // CHECK:   } -> tensor<?x?xf32>
  // CHECK:   scf.yield %[[VAL_18]] : tensor<?x?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG1]] : tensor<?x?xf32>
  // CHECK: }

  // CHECK: %[[VAL_21:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
  // CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[CONST1]] : index
  // CHECK: %[[ARG1_DIM1_BROADCAST:.*]] = scf.if %[[VAL_22]] -> (tensor<?x?xf32>) {
  // CHECK:   %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
  // CHECK:   %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
  // CHECK:   %[[VAL_24:.*]] = tensor.empty(%[[VAL_23]], %[[MAX_DIM1]]) : tensor<?x?xf32>
  // CHECK:   %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_24]] : tensor<?x?xf32>) {
  // CHECK:   ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_26]] : f32
  // CHECK:   } -> tensor<?x?xf32>
  // CHECK:   scf.yield %[[VAL_25]] : tensor<?x?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG1_DIM0_BROADCAST]] : tensor<?x?xf32>
  // CHECK: }

  // CHECK: %[[VAL_28:.*]] = tensor.empty(%[[MAX_DIM0]], %[[MAX_DIM1]]) : tensor<?x?xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM1_BROADCAST]], %[[ARG1_DIM1_BROADCAST]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[VAL_28]] : tensor<?x?xf32>) {
  // CHECK: ^bb0(%[[VAL_29:.*]]: f32, %[[VAL_30:.*]]: f32, %[[VAL_31:.*]]: f32):
  // CHECK:   %[[VAL_32:.*]] = arith.addf %[[VAL_29]], %[[VAL_30]] : f32
  // CHECK:   linalg.yield %[[VAL_32]] : f32
  // CHECK: } -> tensor<?x?xf32>
  %0 = tosa.add %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

  // CHECK: return %[[RESULT]] : tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

// Three-operand tosa.select with one dynamic dimension: all three operands get
// a runtime size-1 check and conditional broadcast on dim 1, then a single
// linalg.generic computes arith.select element-wise.
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @test_select_2d_one_dynamic
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z_]*]]:
func.func @test_select_2d_one_dynamic(%arg0: tensor<2x?xi1>, %arg1: tensor<2x?xf32>, %arg2: tensor<2x?xf32>) -> tensor<2x?xf32> {

  // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
  // CHECK: %[[ARG0_DIM1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<2x?xi1>
  // CHECK: %[[ARG1_DIM1:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<2x?xf32>
  // CHECK: %[[VAL_0:.*]] = arith.maxui %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
  // CHECK: %[[ARG2_DIM1:.*]] = tensor.dim %[[ARG2]], %[[CONST1]] : tensor<2x?xf32>
  // CHECK: %[[MAX_DIM1:.*]] = arith.maxui %[[VAL_0]], %[[ARG2_DIM1]] : index

  // CHECK: %[[VAL_1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<2x?xi1>
  // CHECK: %[[VAL_2:.*]] = arith.cmpi eq, %[[VAL_1]], %[[CONST1]] : index
  // CHECK: %[[ARG0_BROADCAST:.*]] = scf.if %[[VAL_2]] -> (tensor<2x?xi1>) {
  // CHECK:   %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xi1>
  // CHECK:   %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x?xi1>) outs(%[[VAL_3]] : tensor<2x?xi1>) {
  // CHECK:   ^bb0(%[[VAL_5:.*]]: i1, %[[VAL_6:.*]]: i1):
  // CHECK:     linalg.yield %[[VAL_5]] : i1
  // CHECK:   } -> tensor<2x?xi1>
  // CHECK:   scf.yield %[[VAL_4]] : tensor<2x?xi1>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG0]] : tensor<2x?xi1>
  // CHECK: }

  // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<2x?xf32>
  // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
  // CHECK: %[[ARG1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<2x?xf32>) {
  // CHECK:   %[[VAL_9:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
  // CHECK:   %[[VAL_10:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<2x?xf32>) outs(%[[VAL_9]] : tensor<2x?xf32>) {
  // CHECK:   ^bb0(%[[VAL_11:.*]]: f32, %[[VAL_12:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_11]] : f32
  // CHECK:   } -> tensor<2x?xf32>
  // CHECK:   scf.yield %[[VAL_10]] : tensor<2x?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG1]] : tensor<2x?xf32>
  // CHECK: }

  // CHECK: %[[VAL_13:.*]] = tensor.dim %[[ARG2]], %[[CONST1]] : tensor<2x?xf32>
  // CHECK: %[[VAL_14:.*]] = arith.cmpi eq, %[[VAL_13]], %[[CONST1]] : index
  // CHECK: %[[ARG2_BROADCAST:.*]] = scf.if %[[VAL_14]] -> (tensor<2x?xf32>) {
  // CHECK:   %[[VAL_15:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
  // CHECK:   %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG2]] : tensor<2x?xf32>) outs(%[[VAL_15]] : tensor<2x?xf32>) {
  // CHECK:   ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
  // CHECK:     linalg.yield %[[VAL_17]] : f32
  // CHECK:   } -> tensor<2x?xf32>
  // CHECK:   scf.yield %[[VAL_16]] : tensor<2x?xf32>
  // CHECK: } else {
  // CHECK:   scf.yield %[[ARG2]] : tensor<2x?xf32>
  // CHECK: }

  // CHECK: %[[VAL_19:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_BROADCAST]], %[[ARG1_BROADCAST]], %[[ARG2_BROADCAST]] : tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) outs(%[[VAL_19]] : tensor<2x?xf32>) {
  // CHECK: ^bb0(%[[VAL_20:.*]]: i1, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32):
  // CHECK:   %[[VAL_24:.*]] = arith.select %[[VAL_20]], %[[VAL_21]], %[[VAL_22]] : f32
  // CHECK:   linalg.yield %[[VAL_24]] : f32
  // CHECK: } -> tensor<2x?xf32>
  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>

  // CHECK: return %[[RESULT]] : tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}

// -----

// Smoke test: one op of each f32 element-wise TOSA lowering, checking only the
// key scalar op(s) produced inside each linalg.generic.
// CHECK-LABEL: @test_simple_f32
func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
  // CHECK: linalg.generic
  // CHECK: tanh
  %0 = tosa.tanh %arg0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: math.absf
  %1 = tosa.abs %arg0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.addf
  %2 = tosa.add %0, %0 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.subf
  %3 = tosa.sub %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.mulf
  %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.negf
  %5 = tosa.negate %0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: pow
  %6 = tosa.pow %1, %2 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: rsqrt
  %7 = tosa.rsqrt %1 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: log
  %8 = tosa.log %arg0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: exp
  %9 = tosa.exp %arg0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.cmpf
  %10 = tosa.greater %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: arith.cmpf
  %11 = tosa.greater_equal %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: arith.cmpf
  %12 = tosa.equal %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: select
  %13 = tosa.select %10, %0, %1 : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.maximumf
  %14 = tosa.maximum %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.minimumf
  %15 = tosa.minimum %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: ceil
  %16 = tosa.ceil %0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: floor
  %17 = tosa.floor %0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.minimumf
  // CHECK: arith.maximumf
  %18 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.negf
  // CHECK: exp
  // CHECK: arith.addf
  // CHECK: arith.divf
  %19 = tosa.sigmoid %0 : (tensor<1xf32>) -> tensor<1xf32>

  // f32 -> i32 cast: round-to-even, then saturate at the i32 bounds.
  // CHECK: linalg.generic
  // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f32
  // CHECK: [[CSTMIN:%.+]] = arith.constant -2.14748365E+9 : f32
  // CHECK: [[CSTMAXP1:%.+]] = arith.constant 2.14748365E+9 : f32
  // CHECK: [[CSTMAX:%.+]] = arith.constant 2147483647 : i32
  // CHECK: [[MAX:%.+]] = arith.maximumf [[ROUND]], [[CSTMIN]] : f32
  // CHECK: [[CONV:%.+]] = arith.fptosi [[MAX]] : f32 to i32
  // CHECK: [[CMP:%.+]] = arith.cmpf uge, [[ROUND]], [[CSTMAXP1]] : f32
  // CHECK: arith.select [[CMP]], [[CSTMAX]], [[CONV]] : i32
  %20 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.constant 0
  // CHECK: arith.cmpf
  %21 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: arith.truncf
  %22 = tosa.cast %0 : (tensor<1xf32>) -> tensor<1xf16>

  // CHECK: linalg.generic
  // CHECK: arith.divf
  %23 = tosa.reciprocal %0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: math.erf
  %24 = tosa.erf %0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: math.sin
  %25 = tosa.sin %arg0 : (tensor<1xf32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: math.cos
  %26 = tosa.cos %arg0 : (tensor<1xf32>) -> tensor<1xf32>

  return
}

// -----

// f16 casts: widening to f32, saturating cast to i8 (clamped in f16), and
// cast to i32 where the f16 infinities map to the i32 min/max.
// CHECK-LABEL: @test_simple_f16
func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {

  // CHECK: linalg.generic
  // CHECK: arith.extf
  %0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: [[ROUND:%.+]] = math.roundeven {{%.+}} : f16
  // CHECK: [[CSTMIN:%.+]] = arith.constant -1.280000e+02 : f16
  // CHECK: [[CSTMAX:%.+]] = arith.constant 1.270000e+02 : f16
  // CHECK: [[MIN:%.+]] = arith.minimumf [[ROUND]], [[CSTMAX]] : f16
  // CHECK: [[CLAMP:%.+]] = arith.maximumf [[MIN]], [[CSTMIN]] : f16
  // CHECK: arith.fptosi [[CLAMP]] : f16 to i8
  %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8>

  // CHECK: linalg.generic
  // CHECK: [[ROUND:%.+]] = math.roundeven {{%[a-z0-9_]+}} : f16
  // CHECK: [[CONV:%.+]] = arith.fptosi [[ROUND]] : f16 to i32
  // CHECK: [[POSINF:%.+]] = arith.constant 0x7C00 : f16
  // CHECK: [[NEGINF:%.+]] = arith.constant 0xFC00 : f16
  // CHECK: [[OVERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[POSINF]] : f16
  // CHECK: [[UNDERFLOW:%.+]] = arith.cmpf ueq, [[ROUND]], [[NEGINF]] : f16
  // CHECK: [[MININT:%.+]] = arith.constant -2147483648 : i32
  // CHECK: [[MAXINT:%.+]] = arith.constant 2147483647 : i32
  // CHECK: [[CLAMPPOSINF:%.+]] = arith.select [[OVERFLOW]], [[MAXINT]], [[CONV]] : i32
  // CHECK: arith.select [[UNDERFLOW]], [[MININT]], [[CLAMPPOSINF]] : i32
  %2 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>
  return
}

// -----

// i16 x i16 -> i32 multiply: both operands are sign-extended before arith.muli.
// CHECK-LABEL: @test_simple_i16
func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
  // CHECK: linalg.generic
  // CHECK: arith.extsi
  // CHECK: arith.extsi
  // CHECK: arith.muli
  %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>

  return
}

// -----

// tosa.cast from an unsigned ui8 tensor to f32 lowers to arith.uitofp.
// CHECK-LABEL: @test_simple_ui8
func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
  // CHECK: arith.uitofp
  %0 = tosa.cast %arg0 : (tensor<1xui8>) -> tensor<1xf32>
  return
}
634
635// -----
636
// Elementwise TOSA ops on i32 (plus unsigned clamp variants on ui32/ui64) and
// the arith/math ops they lower to inside linalg.generic regions.
// CHECK-LABEL: @test_simple_i32
func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %unsigned64: tensor<1xui64>) -> () {
  // CHECK: linalg.generic
  // CHECK: arith.addi
  %0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.subi
  %1 = tosa.sub %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.muli
  %shift1 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %2 = tosa.mul %arg0, %arg0, %shift1 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.constant 2
  // CHECK: apply_scale
  %shift2 = "tosa.const"() <{value = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
  %3 = tosa.mul %arg0, %arg0, %shift2: (tensor<1xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.divsi
  %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
  // CHECK: [[ZERO:%.+]] = arith.constant 0
  // CHECK: arith.subi [[ZERO]], %[[ARG1]]
  %5 = tosa.negate %arg0 : (tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: and
  %6 = tosa.bitwise_and %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: or
  %7 = tosa.bitwise_or %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.xori
  %8 = tosa.bitwise_xor %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.shli
  %9 = tosa.logical_left_shift %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.shrui
  %10 = tosa.logical_right_shift %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.shrsi
  %11 = tosa.arithmetic_right_shift %arg0, %arg0 {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // With round = 1 the lowering adds the expanded round-to-nearest sequence
  // after the arithmetic shift.
  // CHECK: linalg.generic
  // CHECK: arith.constant 1
  // CHECK: arith.constant 0
  // CHECK: arith.constant true
  // CHECK: arith.cmpi
  // CHECK: arith.subi
  // CHECK: arith.shrsi
  // CHECK: arith.trunci
  // CHECK: and
  // CHECK: and
  // CHECK: arith.extui
  // CHECK: arith.addi
  %12 = tosa.arithmetic_right_shift %arg0, %arg0 {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: math.ctlz
  %13 = tosa.clz %arg0 : (tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.cmpi
  %14 = tosa.greater %0, %1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: arith.cmpi
  %15 = tosa.greater_equal %0, %1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: select
  %16 = tosa.select %14, %0, %1 : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.maxsi
  %17 = tosa.maximum %0, %1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK: arith.minsi
  %18 = tosa.minimum %0, %1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

  // CHECK: linalg.generic
  // CHECK-DAG: arith.maxsi
  // CHECK-DAG: arith.minsi
  %19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>

  // Unsigned clamp uses the unsigned comparisons (maxui/minui).
  // CHECK: linalg.generic
  // CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
  // CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
  // CHECK-DAG: arith.maxui %[[LB]],
  // CHECK-DAG: arith.minui %[[UB]],
  %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

  // Bounds above the ui32 range saturate to the unsigned maximum, whose bit
  // pattern prints as -1 on the signless i32 constant.
  // CHECK: linalg.generic
  // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
  // CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
  // CHECK-DAG: arith.maxui %[[LB]],
  // CHECK-DAG: arith.minui %[[UB]],
  %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

  // Negative bounds saturate to 0 for unsigned operands.
  // CHECK: linalg.generic
  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
  // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
  // CHECK-DAG: arith.maxui %[[LB]],
  // CHECK-DAG: arith.minui %[[UB]],
  %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

  // CHECK: linalg.generic
  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
  // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
  // CHECK-DAG: arith.maxui %[[LB]],
  // CHECK-DAG: arith.minui %[[UB]],
  %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>

  // CHECK: linalg.generic
  // CHECK: arith.trunci
  %20 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi16>

  // CHECK: linalg.generic
  // CHECK: arith.extsi
  %21 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi64>

  // CHECK: linalg.generic
  // CHECK: arith.constant 0
  // CHECK: arith.cmpi
  %22 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: arith.sitofp
  %23 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xf32>

  // CHECK: linalg.generic
  // CHECK: arith.constant 0
  // CHECK: arith.subi
  // CHECK: arith.maxsi
  %24 = tosa.abs %arg0 : (tensor<1xi32>) -> tensor<1xi32>

  return
}
787
788// -----
789
// The input here is signed i8 (despite the historical "ui8" label), so the
// cast to f32 must lower to a signed conversion (sitofp). Renamed from
// @test_simple_ui8: that name collided with the genuinely-unsigned test
// earlier in this file, and duplicate FileCheck labels are ambiguous because
// label patterns must match uniquely across the whole output.
// CHECK-LABEL: @test_simple_i8_sitofp
func.func @test_simple_i8_sitofp(%arg0: tensor<1xi8>) -> () {

  // CHECK: linalg.generic
  // CHECK: sitofp
  %0 = tosa.cast %arg0 : (tensor<1xi8>) -> tensor<1xf32>

  return
}
799
800// -----
801
// tosa.clamp on i8: in-range bounds (-127, 126) are used verbatim, while
// out-of-range bounds (-130, 130) are tightened to the i8 limits [-128, 127].
// CHECK-LABEL: @test_i8
func.func @test_i8(%arg0: tensor<1xi8>) -> () {
  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[ARG1:.+]]: i8,
  // CHECK-DAG: %[[C127:.+]] = arith.constant -127
  // CHECK-DAG: %[[C126:.+]] = arith.constant 126
  // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
  // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
  %0 = tosa.clamp %arg0 {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>

  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[ARG1:.+]]: i8,
  // CHECK-DAG: %[[C128:.+]] = arith.constant -128
  // CHECK-DAG: %[[C127:.+]] = arith.constant 127
  // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C128]], %[[ARG1]]
  // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C127]], %[[LOWER]]
  %1 = tosa.clamp %arg0 {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>

  return
}
822
823// -----
824
// tosa.clamp spanning the full i64 range keeps the exact i64 limits as the
// clamp constants.
// CHECK-LABEL: @test_i64
func.func @test_i64(%arg0: tensor<1xi64>) -> () {
  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[ARG1:.+]]: i64,
  // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
  // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
  // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
  // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
  %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>

  return
}
837
838// -----
839
// Float clamp: only the min_fp/max_fp attributes drive the lowering
// (minimumf then maximumf); the integer bounds do not appear here.
// CHECK-LABEL: @test_clamp_f16
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[ARG1:.+]]: f16,
  // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
  // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
  // CHECK-DAG: %[[MIN:.+]] = arith.minimumf %[[ARG1]], %[[C6]]
  // CHECK-DAG: %[[MAX:.+]] = arith.maximumf %[[MIN]], %[[C0]]
  %0 = tosa.clamp %arg0 {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>

  return
}
852
853// -----
854
// Boolean logical ops lower to arith bit operations on i1; logical_not is
// implemented as xor with constant true.
// CHECK-LABEL: @test_bool
func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
  // CHECK: linalg.generic
  // CHECK: and
  %0 = tosa.logical_and %arg0, %arg1 : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: or
  %1 = tosa.logical_or %arg0, %arg1 : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: arith.xori
  %2 = tosa.logical_xor %arg0, %arg1 : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1>

  // CHECK: linalg.generic
  // CHECK: arith.constant true
  // CHECK: arith.xori
  %3 = tosa.logical_not %arg0 : (tensor<1xi1>) -> tensor<1xi1>

  return
}
876
877// -----
878
// Quantized tosa.negate: the subtraction is done at a wider intermediate
// width and the result clamped back to i8. The tests pin the width boundary
// at input_zp 32639 (i16) vs 32640 (i32) — presumably the point where the
// worst-case intermediate no longer fits in i16; confirm against the lowering.
// Zero zero-points shortcut to a plain 0 - x with no widening or clamping.
// CHECK-LABEL: @test_negate_quantized
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
  // CHECK: [[CNST:%.+]] = arith.constant 7
  // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
  // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
  // CHECK: [[MIN:%.+]] = arith.constant -128
  // CHECK: [[MAX:%.+]] = arith.constant 127
  // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
  // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
  // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
  // CHECK: linalg.yield [[TRUNC]]
  %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>

  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
  // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
  %1 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32639, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>

  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
  // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
  %2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>

  // CHECK: linalg.generic
  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
  // CHECK: [[ZERO:%.+]] = arith.constant 0
  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
  // CHECK: linalg.yield [[SUB]]
  %3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>

  return
}
913
914// -----
915
// tosa.identity folds away completely: the original block arguments are
// returned directly with no ops in between.
// CHECK-LABEL: @test_identity
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>
func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
  %0 = tosa.identity %arg0 : (tensor<1xf32>) -> tensor<1xf32>
  %1 = tosa.identity %arg1 : (tensor<1xi32>) -> tensor<1xi32>

  // CHECK: return %[[ARG0]], %[[ARG1]]
  return %0, %1 : tensor<1xf32>, tensor<1xi32>
}
926
927// -----
928
// Float reductions: fill an empty tensor with the op's identity (0 for sum,
// 1 for prod, +/-FLT_MAX for min/max), run linalg.reduce over the axis, then
// expand_shape to restore the kept size-1 dimension.
// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
  // CHECK: [[CST0:%.+]] = arith.constant 0.0
  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
  // CHECK:  (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) {
  // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
  // CHECK:   linalg.yield [[RES]] : f32
  // CHECK:  }
  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xf32> into tensor<1x4xf32>
  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>

  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
  // CHECK: [[CST0:%.+]] = arith.constant 0.0
  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>) dimensions = [1]
  // CHECK:  (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) {
  // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
  // CHECK:   linalg.yield [[RES]] : f32
  // CHECK:  }
  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xf32> into tensor<5x1xf32>
  %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xf32>) -> tensor<5x1xf32>

  // CHECK: arith.constant 1.0
  // CHECK: linalg.fill
  // CHECK: linalg.reduce
  // CHECK: arith.mulf
  %2 = tosa.reduce_prod %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>

  // CHECK: arith.constant 3.40282347E+38 : f32
  // CHECK: linalg.fill
  // CHECK: linalg.reduce
  // CHECK: arith.minimumf
  %3 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>

  // CHECK: arith.constant -3.40282347E+38 : f32
  // CHECK: linalg.fill
  // CHECK: linalg.reduce
  // CHECK: arith.maximumf
  %4 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>
  return
}
973
974// -----
975
// Reduction with a dynamic leading dim: tensor.empty takes the queried dim,
// and expand_shape receives an explicit dynamic output_shape operand.
// CHECK-LABEL: @reduce_float_dyn
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<?x5x4xf32>
func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
  // CHECK: %[[C0:.+]] = arith.constant 0
  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[C0]]
  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x4xf32>
  // CHECK: %[[CST0:.+]] = arith.constant 0.0
  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT]]
  // CHECK: %[[REDUCE:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<?x5x4xf32>) outs(%[[FILL]] : tensor<?x4xf32>) dimensions = [1]
  // CHECK:  (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) {
  // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
  // CHECK:   linalg.yield %[[RES]] : f32
  // CHECK:  }
  // CHECK: %[[C0_0:.+]] = arith.constant 0 : index
  // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C0_0]] : tensor<?x4xf32>
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] output_shape [%[[DIM_1]], 1, 4] : tensor<?x4xf32> into tensor<?x1x4xf32>
  %0 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
  return
}
996
997// -----
998
// Rank-1 reduction collapses to a 0-d tensor, then expands back to
// tensor<1xf32> with an empty reassociation.
// CHECK-LABEL: @reduce_float_dyn_rank_1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<?xf32>
func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<f32>
  // CHECK-DAG: %[[CST0:.+]] = arith.constant 0.0
  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT]]
  // CHECK: %[[REDUCE:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<?xf32>) outs(%[[FILL]] : tensor<f32>) dimensions = [0]
  // CHECK:  (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) {
  // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
  // CHECK:   linalg.yield %[[RES]] : f32
  // CHECK:  }
  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}] output_shape [1] : tensor<f32> into tensor<1xf32>
  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<?xf32>) -> tensor<1xf32>
  return
}
1014
1015// -----
1016
// reduce_prod over the trailing axis with a dynamic middle dim; the fill
// value is the multiplicative identity 1.0.
// CHECK-LABEL: @reduce_float_dyn_nonzero_batch
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
  // CHECK: %[[C1:.+]] = arith.constant 1
  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[C1]]
  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<5x?xf32>
  // CHECK: %[[CST1:.+]] = arith.constant 1.0
  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT]]
  // CHECK: %[[REDUCE:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<5x?x4xf32>) outs(%[[FILL]] : tensor<5x?xf32>) dimensions = [2]
  // CHECK:  (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) {
  // CHECK:   %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
  // CHECK:   linalg.yield %[[RES]] : f32
  // CHECK:  }
  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
  // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C1_0]] : tensor<5x?xf32>
  // CHECK: %[[C1_2:.+]] = arith.constant 1 : index
  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] output_shape [5, %[[DIM_1]], 1] : tensor<5x?xf32> into tensor<5x?x1xf32>
  %0 = tosa.reduce_prod %arg0 {axis = 2 : i32} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
  return
}
1037
1038// -----
1039
// reduce_max with both dims dynamic; the fill value is the lowest finite f32
// (-3.40282347E+38).
// CHECK-LABEL: @reduce_float_dyn_multiple
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
  // CHECK: %[[C0:.+]] = arith.constant 0
  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[C0]]
  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]])
  // CHECK: %[[CMIN:.+]] = arith.constant -3.40282347E+38
  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]]
  // CHECK: %[[REDUCE:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[FILL]] : tensor<?xf32>) dimensions = [1]
  // CHECK:  (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) {
  // CHECK:   %[[MAX:.+]] = arith.maximumf %[[ARG1]], %[[ARG2]] : f32
  // CHECK:   linalg.yield %[[MAX]] : f32
  // CHECK:  }
  // CHECK: %[[C0_0:.+]] = arith.constant 0 : index
  // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C0_0]] : tensor<?xf32>
  // CHECK: %[[C1_2:.+]] = arith.constant 1 : index
  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0, 1]] output_shape [%[[DIM_1]], 1] : tensor<?xf32> into tensor<?x1xf32>
  %0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<?x?xf32>) -> tensor<?x1xf32>
  return
}
1060
1061// -----
1062
// Integer reductions mirror the float ones with integer identities:
// 0 (sum), 1 (prod), INT32_MAX (min), INT32_MIN (max).
// CHECK-LABEL: @reduce_int
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32>
func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[CST0:%.+]] = arith.constant 0
  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>) dimensions = [0]
  // CHECK:  (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) {
  // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
  // CHECK:   linalg.yield [[RES]] : i32
  // CHECK:  }
  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xi32> into tensor<1x4xi32>
  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>

  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[CST0:%.+]] = arith.constant 0
  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>) dimensions = [1]
  // CHECK:  (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) {
  // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
  // CHECK:   linalg.yield [[RES]] : i32
  // CHECK:  }
  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xi32> into tensor<5x1xi32>
  %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xi32>) -> tensor<5x1xi32>

  // CHECK: arith.constant 1
  // CHECK: linalg.fill
  // CHECK: linalg.reduce
  // CHECK: arith.muli
  %2 = tosa.reduce_prod %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>

  // CHECK: arith.constant 2147483647 : i32
  // CHECK: linalg.fill
  // CHECK: linalg.reduce
  // CHECK: arith.minsi
  %3 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>

  // CHECK: arith.constant -2147483648 : i32
  // CHECK: linalg.fill
  // CHECK: linalg.reduce
  // CHECK: arith.maxsi
  %4 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
  return
}
1107
1108// -----
1109
// Boolean reductions: reduce_all folds with andi starting from true;
// reduce_any folds with or starting from false.
// CHECK-LABEL: @reduce_bool
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi1>
func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[CST0:%.+]] = arith.constant true
  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<4xi1>) dimensions = [0]
  // CHECK:  (%[[ARG1:[0-9a-zA-Z_]+]]: i1, %[[ARG2:[0-9a-zA-Z_]+]]: i1) {
  // CHECK:   [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
  // CHECK:   linalg.yield [[RES]] : i1
  // CHECK:  }
  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xi1> into tensor<1x4xi1>
  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<5x4xi1>) -> tensor<1x4xi1>

  // CHECK: arith.constant false
  // CHECK: linalg.fill
  // CHECK: linalg.reduce
  // CHECK: or
  %1 = tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<5x4xi1>) -> tensor<1x4xi1>

  return
}
1132
1133// -----
// Basic i8 tosa.rescale: sign-extend to i32, subtract input_zp, apply_scale
// with the constant multiplier/shift, add output_zp, clamp to [-128, 127],
// then truncate back to i8.
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

// CHECK-LABEL: @rescale_i8
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
  // CHECK: [[C0:%.+]] = arith.constant 19689
  // CHECK: [[C1:%.+]] = arith.constant 15
  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
  // CHECK: [[C17:%.+]] = arith.constant 17
  // CHECK: [[C22:%.+]] = arith.constant 22
  // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
  // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
  // CHECK-DAG: linalg.yield [[TRUNC]]
  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>

  // CHECK: return
  return
}
1161
1162// -----
// Same pipeline as the signed rescale, but with output_unsigned = true the
// clamp range becomes [0, 255] while the stored element type stays i8.
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

// CHECK-LABEL: @rescale_i8_unsigned_output
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
  // CHECK: [[C0:%.+]] = arith.constant 19689
  // CHECK: [[C1:%.+]] = arith.constant 15
  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
  // CHECK: [[C17:%.+]] = arith.constant 17
  // CHECK: [[C22:%.+]] = arith.constant 22
  // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
  // CHECK: linalg.yield [[TRUNC]]
  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>

  // CHECK: return
  return
}
1190
1191// -----
1192
// Rescale with a dynamic batch dim: the destination tensor.empty is sized
// from tensor.dim of the input; signed and unsigned-output variants both
// produce the same shape handling.
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @rescale_i8_dyn_batch
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
  // CHECK: %[[C0:.+]] = arith.constant 0
  // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>

  // CHECK: %[[C0:.+]] = arith.constant 0
  // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, output_unsigned = true} : (tensor<?x2xi8>) -> tensor<?x2xi8>

  return
}
1212
1213// -----
1214
// 4-D rescale i32 -> i8 with two dynamic dims: both dynamic extents are
// queried and passed to tensor.empty for the i8 destination.
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// CHECK-LABEL: @rescale_dyn
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
  // CHECK: %[[C1:.+]] = arith.constant 1
  // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
  // CHECK: %[[C2:.+]] = arith.constant 2
  // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
  %0 = tosa.rescale %arg0 {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1376784203>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 38>} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
  return
}
1229
1230// -----
1231
// input_unsigned = true: the input is zero-extended (extui) instead of
// sign-extended; the rest of the pipeline and the signed output clamp are
// unchanged.
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

// CHECK-LABEL: @rescale_i8_unsigned_input
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
  // CHECK: [[C0:%.+]] = arith.constant 19689
  // CHECK: [[C1:%.+]] = arith.constant 15
  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
  // CHECK: [[C17:%.+]] = arith.constant 17
  // CHECK: [[C22:%.+]] = arith.constant 22
  // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
  // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
  // CHECK: linalg.yield [[TRUNC]]
  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>

  return
}
1258
1259// -----
1260
// Per-channel multipliers/shifts are materialized as constant tensors and fed
// to the generic as extra inputs. The expected constants zero out the third
// channel (multiplier 44 / shift 64 become 0) — presumably because shift 64
// is outside the valid range; confirm against the lowering.
// NOTE(review): the op carries per_channel = false even though the arrays
// have three entries and the test name says per-channel — verify whether the
// attribute should be true here.
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

// CHECK-LABEL: @rescale_per_channel
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
  // CHECK: [[MULTIPLIERS:%.+]] = arith.constant dense<[42, 43, 0]>
  // CHECK: [[SHIFTS:%.+]] = arith.constant dense<[14, 15, 0]>
  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[MULTIPLIERS]], [[SHIFTS]] : tensor<3xi8>, tensor<3xi32>, tensor<3xi8>) outs([[INIT]] : tensor<3xi8>)
  // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8):
  // CHECK: [[C243:%.+]] = arith.constant 243
  // CHECK: [[C252:%.+]] = arith.constant 252

  // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C243]]
  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]] {double_round = false}
  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C252]]
  // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
  // CHECK-DAG: linalg.yield [[TRUNC]]
  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>) -> tensor<3xi8>

  // CHECK: return [[GENERIC]]
  return %0 : tensor<3xi8>
}
1289
1290// -----
1291
// With scale32 and shift 33 (> 31), the double_round request is preserved on
// the lowered apply_scale.
// CHECK-LABEL: @rescaleDoubleRound
func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
  // CHECK: linalg.generic
  // CHECK: tosa.apply_scale
  // CHECK-SAME:  {double_round = true}
  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 33>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}
1300
// Same rescale but with shift = 15 (<= 31): the op requests double_round = true,
// yet the lowering is expected to emit double_round = false, i.e. drop the
// unnecessary double rounding for small shifts.
1301// CHECK-LABEL: @rescaleUnnecessaryDoubleRound
1302func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
1303  // CHECK: linalg.generic
1304  // CHECK: tosa.apply_scale
1305  // CHECK-SAME:  {double_round = false}
1306  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
1307  return %0 : tensor<2xi8>
1308}
1309
1310// -----
1311
// tosa.reverse on a static 5x4 tensor lowers to a linalg.generic with no ins;
// the body reads the input via tensor.extract at a flipped index computed as
// (dim(axis) - 1) - linalg.index(axis).
1312// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
1313
1314// CHECK-LABEL: @reverse
1315// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1316func.func @reverse(%arg0: tensor<5x4xi32>) -> () {
// First case: reverse along axis 0 (rows); only index 0 is flipped.
1317  // CHECK: %[[C0:.+]] = arith.constant 0
1318  // CHECK: %[[RDIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
1319  // CHECK: %[[INIT:.+]] = tensor.empty()
1320  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>)
1321  // CHECK-DAG:   %[[I0:.+]] = linalg.index 0
1322  // CHECK-DAG:   %[[I1:.+]] = linalg.index 1
1323  // CHECK-DAG:   %[[SUB1:.+]] = arith.constant 1
1324  // CHECK-DAG:   %[[RDIM_MINUS_C1:.+]] = arith.subi %[[RDIM]], %[[SUB1]]
1325  // CHECK-DAG:   %[[READ_DIM:.+]] = arith.subi %[[RDIM_MINUS_C1]], %[[I0]]
1326  // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]], %[[I1]]] : tensor<5x4xi32>
1327  // CHECK:   linalg.yield %[[EXTRACT]]
1328  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<5x4xi32>
1329
// Second case: reverse along axis 1 (columns); only index 1 is flipped.
1330  // CHECK: %[[C1:.+]] = arith.constant 1
1331  // CHECK: %[[RDIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
1332  // CHECK: %[[INIT:.+]] = tensor.empty()
1333  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>)
1334  // CHECK-DAG:   %[[I0:.+]] = linalg.index 0
1335  // CHECK-DAG:   %[[I1:.+]] = linalg.index 1
1336  // CHECK-DAG:   %[[SUB1:.+]] = arith.constant 1
1337  // CHECK-DAG:   %[[RDIM_MINUS_C1:.+]] = arith.subi %[[RDIM]], %[[SUB1]]
1338  // CHECK-DAG:   %[[READ_DIM:.+]] = arith.subi %[[RDIM_MINUS_C1]], %[[I1]]
1339  // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[I0]], %[[READ_DIM]]] : tensor<5x4xi32>
1340  // CHECK:   linalg.yield %[[EXTRACT]]
1341  %1 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<5x4xi32>) -> tensor<5x4xi32>
1342  return
1343}
1344
1345// -----
1346
// tosa.reverse on a fully dynamic 1-D tensor: the reversed dimension size is
// queried twice with tensor.dim (once to size tensor.empty, once to compute the
// flipped read index inside the generic body).
1347// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1348
1349// CHECK-LABEL: @reverse_dyn
1350// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1351func.func @reverse_dyn(%arg0: tensor<?xi32>) -> () {
1352  // CHECK: %[[C0_1:.+]] = arith.constant 0
1353  // CHECK: %[[D0_1:.+]] = tensor.dim %[[ARG0]], %[[C0_1]]
1354  // CHECK: %[[C0_2:.+]] = arith.constant 0
1355  // CHECK: %[[D0_2:.+]] = tensor.dim %[[ARG0]], %[[C0_2]]
1356  // CHECK: %[[INIT:.+]] = tensor.empty(%[[D0_1]])
1357  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel"]} outs(%[[INIT]] : tensor<?xi32>)
1358  // CHECK-DAG:   %[[I0:.+]] = linalg.index 0
1359  // CHECK-DAG:   %[[SUB1:.+]] = arith.constant 1
1360  // CHECK-DAG:   %[[RDIM_MINUS_C1:.+]] = arith.subi %[[D0_2]], %[[SUB1]]
1361  // CHECK-DAG:   %[[READ_DIM:.+]] = arith.subi %[[RDIM_MINUS_C1]], %[[I0]]
1362  // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]]] : tensor<?xi32>
1363  // CHECK:   linalg.yield %[[EXTRACT]]
1364  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<?xi32>) -> tensor<?xi32>
1365  return
1366}
1367
1368// -----
1369
// tosa.tile on a static 2x3 input: each case lowers to a rank-4 broadcasting
// linalg.generic (tile counts interleaved with input dims) followed by a
// tosa.reshape that collapses back to the requested 2-D result shape.
1370// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
1371// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1372
1373// CHECK-LABEL: @tile
1374// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3xi8>
1375func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
// Multiples [2, 1]: intermediate 2x2x1x3, reshaped to 4x3.
1376  // CHECK: [[INIT:%.+]] = tensor.empty()
1377  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
1378  // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
1379  // CHECK:   linalg.yield %[[ARG1]] : i8
1380  // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 4, 3>}
1381  %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
1382  %0 = tosa.tile %arg0, %cst21: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<4x3xi8>
1383
// Multiples [1, 2]: intermediate 1x2x2x3, reshaped to 2x6.
1384  // CHECK: [[INIT:%.+]] = tensor.empty()
1385  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
1386  // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
1387  // CHECK:   linalg.yield %[[ARG1]] : i8
1388  // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 2, 6>}
1389  %cst12 = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
1390  %1 = tosa.tile %arg0, %cst12: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x6xi8>
1391
// Multiples [5, 7]: intermediate 5x2x7x3, reshaped to 10x21.
1392  // CHECK: [[INIT:%.+]] = tensor.empty()
1393  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
1394  // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
1395  // CHECK:   linalg.yield %[[ARG1]] : i8
1396  // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 10, 21>}
1397  %cst57 = tosa.const_shape { value = dense<[5, 7]> : tensor<2xindex> } : () -> !tosa.shape<2>
1398  %2 = tosa.tile %arg0, %cst57: (tensor<2x3xi8>, !tosa.shape<2>)  -> tensor<10x21xi8>
1399
1400  return
1401}
1402
1403// -----
1404
// tosa.tile with a dynamic input dimension: the dynamic extent is read with
// tensor.dim to size tensor.empty, and the trailing tosa.reshape carries the
// dynamic-dimension sentinel (-9223372036854775808, i.e. ShapedType::kDynamic)
// in new_shape.
1405// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
1406// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1407
1408// CHECK-LABEL: @tile_dyn_input
1409// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1410func.func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
1411  // CHECK: %[[CST0:.+]] = arith.constant 0
1412  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x3xi8>
1413  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]])
1414  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x3xi8>) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
1415  // CHECK: ^bb0(%[[ARG1:.+]]: i8,
1416  // CHECK:   linalg.yield %[[ARG1]] : i8
1417  // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: -9223372036854775808, 3>}
1418  %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
1419  %0 = tosa.tile %arg0, %cst21: (tensor<?x3xi8>, !tosa.shape<2>)  -> tensor<?x3xi8>
1420
1421  return
1422}
1423
1424// -----
1425
// tosa.tile with a dynamic multiple (-1 in the const shape): the tile count
// dimension of the intermediate tensor becomes dynamic (2x2x?x3) and the final
// reshape uses the ShapedType::kDynamic sentinel for that axis.
1426// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
1427// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1428
1429// CHECK-LABEL: @tile_dyn_multiples
1430// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1431func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
1432  // CHECK: %[[CST1:.+]] = arith.constant 1
1433  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<2x3xi8>
1434  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]])
1435  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
1436  // CHECK: ^bb0(%[[ARG1:.+]]: i8,
1437  // CHECK:   linalg.yield %[[ARG1]] : i8
1438  // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: 2, -9223372036854775808>}
1439  %cst = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
1440  %0 = tosa.tile %arg0, %cst: (tensor<2x3xi8>, !tosa.shape<2>)  -> tensor<2x?xi8>
1441
1442  return
1443}
1444
1445// -----
1446
// tosa.argmax lowers to a two-result linalg.generic: one accumulator tensor for
// the running max value (filled with the element type's minimum) and one for the
// winning index (filled with 0); the body compares, then selects both outputs.
// NOTE(review): this function has no CHECK-LABEL and reuses %[[ARG0]] and the
// raw #map/#map2 attribute names, so its checks depend on bindings/output from
// the preceding test section — confirm this is intentional before reorganizing.
1447// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
1448// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
1449// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
1450// CHECK: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
1451// CHECK: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
1452
1453func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
// Case 1: i32 argmax over axis 0 — reduction on d0, index taken from linalg.index 0.
1454  // CHECK: [[IDX_INIT:%.+]] = tensor.empty()
1455  // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32
1456  // CHECK: [[IDX_FILL:%.+]] = linalg.fill ins([[IDX_MIN]]{{.*}}outs([[IDX_INIT]]
1457  // CHECK: [[VAL_INIT:%.+]] = tensor.empty()
1458  // CHECK: [[VAL_MIN:%.+]] = arith.constant -2147483648
1459  // CHECK: [[VAL_FILL:%.+]] = linalg.fill ins([[VAL_MIN]]{{.*}}outs([[VAL_INIT]]
1460  // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%[[ARG0]] : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<2xi32>, tensor<2xi32>)
1461  // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i32, %[[ARG2:[0-9a-zA-Z_]+]]: i32, %[[ARG3:[0-9a-zA-Z_]+]]: i32
1462  // CHECK:   [[IDX:%.+]] = linalg.index 0
1463  // CHECK:   [[CAST:%.+]] = arith.index_cast [[IDX]]
1464  // CHECK:   [[CMP:%.+]] = arith.cmpi sgt, %[[ARG1]], %[[ARG3]]
1465  // CHECK:   [[SELECT_VAL:%.+]] = arith.select [[CMP]], %[[ARG1]], %[[ARG3]]
1466  // CHECK:   [[SELECT_IDX:%.+]] = arith.select [[CMP]], [[CAST]], %[[ARG2]]
1467  // CHECK:   linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
1468  %0 = tosa.argmax %arg0 { axis = 0 : i32} : (tensor<3x2xi32>)  -> tensor<2xi32>
1469
// Case 2: i32 argmax over axis 1 — reduction on d1, index taken from linalg.index 1.
1470  // CHECK: [[IDX_INIT:%.+]] = tensor.empty()
1471  // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32
1472  // CHECK: [[IDX_FILL:%.+]] = linalg.fill ins([[IDX_MIN]]{{.*}}outs([[IDX_INIT]]
1473  // CHECK: [[VAL_INIT:%.+]] = tensor.empty()
1474  // CHECK: [[VAL_MIN:%.+]] = arith.constant -2147483648
1475  // CHECK: [[VAL_FILL:%.+]] = linalg.fill ins([[VAL_MIN]]{{.*}}outs([[VAL_INIT]]
1476  // CHECK: linalg.generic {indexing_maps = [#map, #map2, #map2], iterator_types = ["parallel", "reduction"]} ins(%[[ARG0]] : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
1477  // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i32, %[[ARG2:[0-9a-zA-Z_]+]]: i32, %[[ARG3:[0-9a-zA-Z_]+]]: i32
1478  // CHECK:   [[IDX:%.+]] = linalg.index 1
1479  // CHECK:   [[CAST:%.+]] = arith.index_cast [[IDX]]
1480  // CHECK:   [[CMP:%.+]] = arith.cmpi sgt, %[[ARG1]], %[[ARG3]]
1481  // CHECK:   [[SELECT_VAL:%.+]] = arith.select [[CMP]], %[[ARG1]], %[[ARG3]]
1482  // CHECK:   [[SELECT_IDX:%.+]] = arith.select [[CMP]], [[CAST]], %[[ARG2]]
1483  // CHECK:   linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
1484  %1 = tosa.argmax %arg0 { axis = 1 : i32} : (tensor<3x2xi32>)  -> tensor<3xi32>
1485
// Case 3: f32 argmax — value accumulator starts at the f32 lowest value and the
// comparison becomes arith.cmpf ogt.
1486  // CHECK: arith.constant -3.40282347E+38 : f32
1487  // CHECK: linalg.index
1488  // CHECK: arith.index_cast
1489  // CHECK: arith.cmpf ogt
1490  // CHECK: select
1491  // CHECK: select
1492  // CHECK: linalg.yield
1493  %2 = tosa.argmax %arg1 { axis = 0 : i32} : (tensor<6xf32>)  -> tensor<i32>
1494
1495  return
1496}
1497
1498// -----
1499
// tosa.argmax where the dynamic dimension is NOT the reduction axis: both
// accumulator tensors must be allocated with the dynamic extent via tensor.dim.
// NOTE(review): no CHECK-LABEL here either; %[[ARG0]] relies on a binding from
// an earlier section of this split-input file.
1500// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
1501// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
1502
1503func.func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () {
1504  // CHECK: %[[CST1:.+]] = arith.constant 1
1505  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[CST1]]
1506  // CHECK: %[[IDX_INIT:.+]] = tensor.empty(%[[DYN]])
1507  // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
1508  // CHECK: %[[IDX_FILL:.+]] = linalg.fill ins(%[[IDX_MIN]]{{.*}}outs(%[[IDX_INIT]]
1509  // CHECK: %[[VAL_INIT:.+]] = tensor.empty(%[[DYN]])
1510  // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
1511  // CHECK: %[[VAL_FILL:.+]] = linalg.fill ins(%[[VAL_MIN]]{{.*}}outs(%[[VAL_INIT]]
1512  // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%[[ARG0]] : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<?xi32>, tensor<?xi32>)
1513  // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i32, %[[ARG2:[0-9a-zA-Z_]+]]: i32, %[[ARG3:[0-9a-zA-Z_]+]]: i32
1514  // CHECK:   %[[IDX:.+]] = linalg.index 0
1515  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[IDX]]
1516  // CHECK:   %[[CMP:.+]] = arith.cmpi sgt, %[[ARG1]], %[[ARG3]]
1517  // CHECK:   %[[SELECT_VAL:.+]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG3]]
1518  // CHECK:   %[[SELECT_IDX:.+]] = arith.select %[[CMP]], %[[CAST]], %[[ARG2]]
1519  // CHECK:   linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
1520  %0 = tosa.argmax %arg0 { axis = 0 : i32} : (tensor<3x?xi32>)  -> tensor<?xi32>
1521  return
1522}
1523
1524// -----
1525
// tosa.argmax where the dynamic dimension IS the reduction axis: the result
// shape (3) is static, so the accumulators use tensor.empty() with no dynamic
// operands and only the iteration over the reduced axis is dynamic.
1526// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
1527// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
1528
1529func.func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () {
1530  // CHECK: %[[IDX_INIT:.+]] = tensor.empty()
1531  // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
1532  // CHECK: %[[IDX_FILL:.+]] = linalg.fill ins(%[[IDX_MIN]]{{.*}}outs(%[[IDX_INIT]]
1533  // CHECK: %[[VAL_INIT:.+]] = tensor.empty()
1534  // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
1535  // CHECK: %[[VAL_FILL:.+]] = linalg.fill ins(%[[VAL_MIN]]{{.*}}outs(%[[VAL_INIT]]
1536  // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[ARG0]] : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
1537  // CHECK:   %[[IDX:.+]] = linalg.index 1
1538  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[IDX]]
1539  // CHECK:   %[[CMP:.+]] = arith.cmpi sgt, %[[ARG1]], %[[ARG3]]
1540  // CHECK:   %[[SELECT_VAL:.+]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG3]]
1541  // CHECK:   %[[SELECT_IDX:.+]] = arith.select %[[CMP]], %[[CAST]], %[[ARG2]]
1542  // CHECK:   linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
1543  %0 = tosa.argmax %arg0 { axis = 1 : i32} : (tensor<3x?xi32>)  -> tensor<3xi32>
1544  return
1545}
1546
1547// -----
1548
// tosa.gather (f32, all-static shapes): lowers to a linalg.generic over the
// indices tensor whose body does a tensor.extract from the values tensor at
// [batch, cast(index), channel].
1549// CHECK-LABEL: @gather_float
1550// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
1551// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
1552func.func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
1553  // CHECK: %[[INIT:.+]] = tensor.empty()
1554  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xf32>)
1555  // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: f32)
1556  // CHECK:   %[[IDX0:.+]] = linalg.index 0
1557  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[BBARG0]]
1558  // CHECK:   %[[IDX2:.+]] = linalg.index 2
1559  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xf32>
1560  // CHECK:   linalg.yield %[[EXTRACT]]
1561  %0 = tosa.gather %arg0, %arg1 : (tensor<2x3x2xf32>, tensor<2x3xi32>)  -> tensor<2x3x2xf32>
1562  return
1563}
1564
1565// -----
1566
// tosa.gather with a dynamic batch dimension: only the batch extent is queried
// with tensor.dim and passed to tensor.empty; the body is otherwise identical
// to the static case.
1567// CHECK-LABEL: @gather_float_dyn
1568// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
1569// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
1570func.func @gather_float_dyn(%arg0: tensor<?x3x2xf32>, %arg1: tensor<?x3xi32>) -> () {
1571  // CHECK: %[[C0:.+]] = arith.constant 0
1572  // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
1573  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]])
1574  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x3xi32>) outs(%[[INIT]] : tensor<?x3x2xf32>)
1575  // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: f32)
1576  // CHECK:   %[[IDX0:.+]] = linalg.index 0
1577  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[BBARG0]]
1578  // CHECK:   %[[IDX2:.+]] = linalg.index 2
1579  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x3x2xf32>
1580  // CHECK:   linalg.yield %[[EXTRACT]]
1581  %0 = tosa.gather %arg0, %arg1 : (tensor<?x3x2xf32>, tensor<?x3xi32>)  -> tensor<?x3x2xf32>
1582  return
1583}
1584
1585// -----
1586
// tosa.gather with all dimensions dynamic: batch and channel extents come from
// the values tensor, while the index-count extent comes from the indices
// tensor; all three feed tensor.empty.
1587// CHECK-LABEL: @gather_float_all_dynamic
1588// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
1589// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
1590func.func @gather_float_all_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> () {
1591  // CHECK: %[[C0:.+]] = arith.constant 0
1592  // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
1593  // CHECK: %[[C1:.+]] = arith.constant 1
1594  // CHECK: %[[INDEX:.+]] = tensor.dim %[[ARG1]], %[[C1]]
1595  // CHECK: %[[C2:.+]] = arith.constant 2
1596  // CHECK: %[[CHANNEL:.+]] = tensor.dim %[[ARG0]], %[[C2]]
1597  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[INDEX]], %[[CHANNEL]])
1598  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?x?xf32>)
1599  // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: f32)
1600  // CHECK:   %[[IDX0:.+]] = linalg.index 0
1601  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[BBARG0]]
1602  // CHECK:   %[[IDX2:.+]] = linalg.index 2
1603  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x?x?xf32>
1604  // CHECK:   linalg.yield %[[EXTRACT]]
1605  %0 = tosa.gather %arg0, %arg1 : (tensor<?x?x?xf32>, tensor<?x?xi32>)  -> tensor<?x?x?xf32>
1606  return
1607}
1608
1609// -----
1610
// tosa.gather with i32 element type: same lowering shape as the f32 case,
// checked separately to cover the integer element path.
1611// CHECK-LABEL: @gather_int
1612// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
1613// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
1614func.func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () {
1615  // CHECK: %[[INIT:.+]] = tensor.empty()
1616  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xi32>)
1617  // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32)
1618  // CHECK:   %[[IDX0:.+]] = linalg.index 0
1619  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[BBARG0]]
1620  // CHECK:   %[[IDX2:.+]] = linalg.index 2
1621  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xi32>
1622  // CHECK:   linalg.yield %[[EXTRACT]]
1623  %0 = tosa.gather %arg0, %arg1 : (tensor<2x3x2xi32>, tensor<2x3xi32>)  -> tensor<2x3x2xi32>
1624  return
1625}
1626
1627// -----
1628
// tosa.table with an i8 input and 512-entry i8 table: the lookup index is the
// signed input value offset by +128 (mapping [-128, 127] to [0, 255]) used in a
// direct tensor.extract — no interpolation for the 8-bit path.
1629// CHECK-LABEL: @table8
1630// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1631// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
1632func.func @table8(%arg0: tensor<6xi8>, %arg1: tensor<512xi8>) -> () {
1633  // CHECK: %[[INIT:.+]] = tensor.empty()
1634  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
1635  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
1636  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
1637  // CHECK:   %[[OFFSET:.+]] = arith.constant 128
1638  // CHECK:   %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
1639  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[ADD]]]
1640  // CHECK:   linalg.yield %[[EXTRACT]]
1641  %0 = tosa.table %arg0, %arg1 : (tensor<6xi8>, tensor<512xi8>)  -> tensor<6xi8>
1642  return
1643}
1644
1645// -----
1646
// tosa.table with an i16 input and 513-entry i16 table producing i32: the
// lowering biases the input by 32768, splits it into a 9-bit table index
// (shrui by 7) and a 7-bit fraction (andi 127), reads base and next table
// entries, and linearly interpolates: (base << 7) + (next - base) * fraction.
1647// CHECK-LABEL: @table16
1648// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1649// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
1650func.func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
1651  // CHECK: %[[INIT:.+]] = tensor.empty()
1652  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<6xi16>) outs(%[[INIT]] : tensor<6xi32>)
1653  // CHECK: ^bb0(%[[ARG2:.*]]: i16, %[[ARG3:.*]]: i32)
1654  // CHECK: %[[EXT_IN:.+]] = arith.extsi %[[ARG2]]
1655  // CHECK: %[[C32768:.+]] = arith.constant 32768
1656  // CHECK: %[[C7:.+]] = arith.constant 7
1657  // CHECK: %[[C1:.+]] = arith.constant 1
1658  // CHECK: %[[C127:.+]] = arith.constant 127
1659  // CHECK: %[[INADD:.+]] = arith.addi %[[EXT_IN]], %[[C32768]]
1660  // CHECK: %[[IDX:.+]] = arith.shrui %[[INADD]], %[[C7]]
1661  // CHECK: %[[FRACTION:.+]] = arith.andi %[[INADD]], %[[C127]]
1662  // CHECK: %[[IDXPLUS1:.+]] = arith.addi %[[IDX]], %[[C1]]
1663  // CHECK: %[[IDX_CAST:.+]] = arith.index_cast %[[IDX]]
1664  // CHECK: %[[IDXPLUS1_CAST:.+]] = arith.index_cast %[[IDXPLUS1]]
1665  // CHECK: %[[BASE:.+]] = tensor.extract %[[ARG1]][%[[IDX_CAST]]]
1666  // CHECK: %[[NEXT:.+]] = tensor.extract %[[ARG1]][%[[IDXPLUS1_CAST]]]
1667  // CHECK: %[[BASE_EXT:.+]] = arith.extsi %[[BASE]]
1668  // CHECK: %[[NEXT_EXT:.+]] = arith.extsi %[[NEXT]]
1669  // CHECK: %[[BASE_MUL:.+]] = arith.shli %[[BASE_EXT]], %[[C7]]
1670  // CHECK: %[[DIFF:.+]] = arith.subi %[[NEXT_EXT]], %[[BASE_EXT]]
1671  // CHECK: %[[DIFF_MUL:.+]] = arith.muli %[[DIFF]], %[[FRACTION]]
1672  // CHECK: %[[RESULT:.+]] = arith.addi %[[BASE_MUL]], %[[DIFF_MUL]]
1673  // CHECK: linalg.yield %[[RESULT]]
1674  %0 = tosa.table %arg0, %arg1 : (tensor<6xi16>, tensor<513xi16>)  -> tensor<6xi32>
1675  return
1676}
1677
1678// -----
1679
// i8 table lookup with a dynamic input length: the extent is read via
// tensor.dim to size tensor.empty; the per-element lookup is identical to the
// static @table8 case.
1680// CHECK-LABEL: @table8_dyn
1681// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1682// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
1683func.func @table8_dyn(%arg0: tensor<?xi8>, %arg1: tensor<512xi8>) -> () {
1684  // CHECK: %[[CST0:.+]] = arith.constant 0
1685  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[CST0]]
1686  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]])
1687  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<?xi8>) outs(%[[INIT]] : tensor<?xi8>)
1688  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
1689  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
1690  // CHECK:   %[[OFFSET:.+]] = arith.constant 128
1691  // CHECK:   %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
1692  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[ADD]]]
1693  // CHECK:   linalg.yield %[[EXTRACT]]
1694  %0 = tosa.table %arg0, %arg1 : (tensor<?xi8>, tensor<512xi8>)  -> tensor<?xi8>
1695  return
1696}
1697
1698// -----
1699
// i8 table lookup where the TABLE operand is dynamically sized: the output
// shape comes from the static input, so tensor.empty() takes no operands and
// the lookup logic is unchanged.
1700// CHECK-LABEL: @table8_dyn_table
1701// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1702// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
1703func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
1704  // CHECK: %[[INIT:.+]] = tensor.empty()
1705  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
1706  // CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
1707  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
1708  // CHECK:   %[[OFFSET:.+]] = arith.constant 128
1709  // CHECK:   %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
1710  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[ADD]]]
1711  // CHECK:   linalg.yield %[[EXTRACT]]
1712  %0 = tosa.table %arg0, %arg1 : (tensor<6xi8>, tensor<?xi8>)  -> tensor<6xi8>
1713  return
1714}
1715
1716// -----
// tosa.rfft2d on a static 5x5x8 input: lowers to one 5-D linalg.generic
// (3 parallel output dims + 2 reduction dims over the input plane) that
// accumulates the real part with cos and the imaginary part with sin of the
// angle 2*pi*(oy*iy/H + ox*ix/W). Output width is H x W/2+1 = 5x5x5.
1717// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
1718// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
1719// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
1720
1721// CHECK-LABEL:   func.func @test_static_rfft2d(
1722// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
1723// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
1724// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
1725// CHECK:           %[[VAL_3:.*]] = arith.constant 8 : index
1726// CHECK:           %[[VAL_4:.*]] = arith.constant 4 : index
1727// CHECK:           %[[VAL_5:.*]] = arith.constant 5 : index
1728// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
1729// CHECK:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
1730// CHECK:           %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
1731// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
1732// CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
1733// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
1734// CHECK:           %[[VAL_12:.*]] = arith.constant 1 : index
1735// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
1736// CHECK:           %[[VAL_14:.*]] = arith.constant 2 : index
1737// CHECK:           %[[VAL_15:.*]] = arith.constant 8 : index
1738// CHECK:           %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
1739// CHECK:           %[[VAL_17:.*]] = arith.index_castui %[[VAL_13]] : index to i32
1740// CHECK:           %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
1741// CHECK:           %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
1742// CHECK:           %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
1743// CHECK:           %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
1744// CHECK:           ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
1745// CHECK:             %[[VAL_25:.*]] = linalg.index 1 : index
1746// CHECK:             %[[VAL_26:.*]] = linalg.index 2 : index
1747// CHECK:             %[[VAL_27:.*]] = linalg.index 3 : index
1748// CHECK:             %[[VAL_28:.*]] = linalg.index 4 : index
1749// CHECK:             %[[VAL_29:.*]] = index.mul %[[VAL_27]], %[[VAL_25]]
1750// CHECK:             %[[VAL_30:.*]] = index.mul %[[VAL_28]], %[[VAL_26]]
1751// CHECK:             %[[VAL_31:.*]] = index.remu %[[VAL_29]], %[[VAL_13]]
1752// CHECK:             %[[VAL_32:.*]] = index.remu %[[VAL_30]], %[[VAL_15]]
1753// CHECK:             %[[VAL_33:.*]] = arith.index_castui %[[VAL_31]] : index to i32
1754// CHECK:             %[[VAL_34:.*]] = arith.uitofp %[[VAL_33]] : i32 to f32
1755// CHECK:             %[[VAL_35:.*]] = arith.index_castui %[[VAL_32]] : index to i32
1756// CHECK:             %[[VAL_36:.*]] = arith.uitofp %[[VAL_35]] : i32 to f32
1757// CHECK:             %[[VAL_37:.*]] = arith.divf %[[VAL_34]], %[[VAL_18]] : f32
1758// CHECK:             %[[VAL_38:.*]] = arith.divf %[[VAL_36]], %[[VAL_20]] : f32
1759// CHECK:             %[[VAL_39:.*]] = arith.addf %[[VAL_37]], %[[VAL_38]] : f32
1760// CHECK:             %[[VAL_40:.*]] = arith.mulf %[[VAL_16]], %[[VAL_39]] : f32
1761// CHECK:             %[[VAL_41:.*]] = math.cos %[[VAL_40]] : f32
1762// CHECK:             %[[VAL_42:.*]] = math.sin %[[VAL_40]] : f32
1763// CHECK:             %[[VAL_43:.*]] = arith.mulf %[[VAL_22]], %[[VAL_41]] : f32
1764// CHECK:             %[[VAL_44:.*]] = arith.mulf %[[VAL_22]], %[[VAL_42]] : f32
1765// CHECK:             %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
1766// CHECK:             %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
1767// CHECK:             linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
1768// CHECK:           } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
1769// CHECK:           return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
1770// CHECK:         }
1771func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
1772  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
1773  return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
1774}
1775
1776// -----
1777// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
1778// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
1779// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
1780
1781// CHECK-LABEL:   func.func @test_dynamic_rfft2d(
1782// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
1783// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
1784// CHECK:           %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>
1785// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
1786// CHECK:           %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
1787// CHECK:           %[[VAL_5:.*]] = arith.constant 2 : index
1788// CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
1789// CHECK:           %[[VAL_7:.*]] = arith.constant 1 : index
1790// CHECK:           %[[VAL_8:.*]] = arith.constant 2 : index
1791// CHECK:           %[[VAL_9:.*]] = arith.divui %[[VAL_6]], %[[VAL_8]] : index
1792// CHECK:           %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_7]] : index
1793// CHECK:           %[[VAL_11:.*]] = tensor.empty(%[[VAL_2]], %[[VAL_4]], %[[VAL_10]]) : tensor<?x?x?xf32>
1794// CHECK:           %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
1795// CHECK:           %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_11]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
1796// CHECK:           %[[VAL_14:.*]] = tensor.empty(%[[VAL_2]], %[[VAL_4]], %[[VAL_10]]) : tensor<?x?x?xf32>
1797// CHECK:           %[[VAL_15:.*]] = arith.constant 0.000000e+00 : f32
1798// CHECK:           %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_15]] : f32) outs(%[[VAL_14]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
1799// CHECK:           %[[VAL_17:.*]] = arith.constant 1 : index
1800// CHECK:           %[[VAL_18:.*]] = tensor.dim %[[VAL_0]], %[[VAL_17]] : tensor<?x?x?xf32>
1801// CHECK:           %[[VAL_19:.*]] = arith.constant 2 : index
1802// CHECK:           %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?x?xf32>
1803// CHECK:           %[[VAL_21:.*]] = arith.constant 6.28318548 : f32
1804// CHECK:           %[[VAL_22:.*]] = arith.index_castui %[[VAL_18]] : index to i32
1805// CHECK:           %[[VAL_23:.*]] = arith.uitofp %[[VAL_22]] : i32 to f32
1806// CHECK:           %[[VAL_24:.*]] = arith.index_castui %[[VAL_20]] : index to i32
1807// CHECK:           %[[VAL_25:.*]] = arith.uitofp %[[VAL_24]] : i32 to f32
1808// CHECK:           %[[VAL_26:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<?x?x?xf32>) outs(%[[VAL_13]], %[[VAL_16]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
1809// CHECK:           ^bb0(%[[VAL_27:.*]]: f32, %[[VAL_28:.*]]: f32, %[[VAL_29:.*]]: f32):
1810// CHECK:             %[[VAL_30:.*]] = linalg.index 1 : index
1811// CHECK:             %[[VAL_31:.*]] = linalg.index 2 : index
1812// CHECK:             %[[VAL_32:.*]] = linalg.index 3 : index
1813// CHECK:             %[[VAL_33:.*]] = linalg.index 4 : index
1814// CHECK:             %[[VAL_34:.*]] = index.mul %[[VAL_32]], %[[VAL_30]]
1815// CHECK:             %[[VAL_35:.*]] = index.mul %[[VAL_33]], %[[VAL_31]]
1816// CHECK:             %[[VAL_36:.*]] = index.remu %[[VAL_34]], %[[VAL_18]]
1817// CHECK:             %[[VAL_37:.*]] = index.remu %[[VAL_35]], %[[VAL_20]]
1818// CHECK:             %[[VAL_38:.*]] = arith.index_castui %[[VAL_36]] : index to i32
1819// CHECK:             %[[VAL_39:.*]] = arith.uitofp %[[VAL_38]] : i32 to f32
1820// CHECK:             %[[VAL_40:.*]] = arith.index_castui %[[VAL_37]] : index to i32
1821// CHECK:             %[[VAL_41:.*]] = arith.uitofp %[[VAL_40]] : i32 to f32
1822// CHECK:             %[[VAL_42:.*]] = arith.divf %[[VAL_39]], %[[VAL_23]] : f32
1823// CHECK:             %[[VAL_43:.*]] = arith.divf %[[VAL_41]], %[[VAL_25]] : f32
1824// CHECK:             %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
1825// CHECK:             %[[VAL_45:.*]] = arith.mulf %[[VAL_21]], %[[VAL_44]] : f32
1826// CHECK:             %[[VAL_46:.*]] = math.cos %[[VAL_45]] : f32
1827// CHECK:             %[[VAL_47:.*]] = math.sin %[[VAL_45]] : f32
1828// CHECK:             %[[VAL_48:.*]] = arith.mulf %[[VAL_27]], %[[VAL_46]] : f32
1829// CHECK:             %[[VAL_49:.*]] = arith.mulf %[[VAL_27]], %[[VAL_47]] : f32
1830// CHECK:             %[[VAL_50:.*]] = arith.addf %[[VAL_28]], %[[VAL_48]] : f32
1831// CHECK:             %[[VAL_51:.*]] = arith.subf %[[VAL_29]], %[[VAL_49]] : f32
1832// CHECK:             linalg.yield %[[VAL_50]], %[[VAL_51]] : f32, f32
1833// CHECK:           } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
1834// CHECK:           return %[[VAL_52:.*]]#0, %[[VAL_52]]#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
1835// CHECK:         }
// Lowers tosa.rfft2d with fully dynamic shapes: the single real input yields
// real and imaginary result tensors. Per the expectations above, the H/W
// extents are read at runtime via tensor.dim and fed into the cos/sin angle
// computation inside a 5-D linalg.generic (3 parallel + 2 reduction dims).
func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
1840
1841// -----
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

// Forward FFT (inverse = false) on a fully static 8x8x8 input pair: the H/W
// extents appear as arith.constant 8 (no tensor.dim), and the angle
// 2*pi*(j*m/H + k*n/W) feeds cos/sin directly — unlike the inverse case
// below, there is no multiplication by -1.
// CHECK-LABEL:   func.func @test_static_fft2d(
// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<8x8x8xf32>,
// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
// CHECK:           %[[VAL_2:.*]] = tensor.empty() : tensor<8x8x8xf32>
// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_4:.*]] = linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_2]] : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
// CHECK:           %[[VAL_5:.*]] = tensor.empty() : tensor<8x8x8xf32>
// CHECK:           %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_7:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_5]] : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_9:.*]] = arith.constant 8 : index
// CHECK:           %[[VAL_10:.*]] = arith.constant 2 : index
// CHECK:           %[[VAL_11:.*]] = arith.constant 8 : index
// CHECK:           %[[VAL_12:.*]] = arith.constant 6.28318548 : f32
// CHECK:           %[[VAL_13:.*]] = arith.index_castui %[[VAL_9]] : index to i32
// CHECK:           %[[VAL_14:.*]] = arith.uitofp %[[VAL_13]] : i32 to f32
// CHECK:           %[[VAL_15:.*]] = arith.index_castui %[[VAL_11]] : index to i32
// CHECK:           %[[VAL_16:.*]] = arith.uitofp %[[VAL_15]] : i32 to f32
// CHECK:           %[[VAL_17:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<8x8x8xf32>, tensor<8x8x8xf32>) outs(%[[VAL_4]], %[[VAL_7]] : tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
// CHECK:           ^bb0(%[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32):
// CHECK:             %[[VAL_22:.*]] = linalg.index 1 : index
// CHECK:             %[[VAL_23:.*]] = linalg.index 2 : index
// CHECK:             %[[VAL_24:.*]] = linalg.index 3 : index
// CHECK:             %[[VAL_25:.*]] = linalg.index 4 : index
// CHECK:             %[[VAL_26:.*]] = index.mul %[[VAL_24]], %[[VAL_22]]
// CHECK:             %[[VAL_27:.*]] = index.mul %[[VAL_25]], %[[VAL_23]]
// CHECK:             %[[VAL_28:.*]] = index.remu %[[VAL_26]], %[[VAL_9]]
// CHECK:             %[[VAL_29:.*]] = index.remu %[[VAL_27]], %[[VAL_11]]
// CHECK:             %[[VAL_30:.*]] = arith.index_castui %[[VAL_28]] : index to i32
// CHECK:             %[[VAL_31:.*]] = arith.uitofp %[[VAL_30]] : i32 to f32
// CHECK:             %[[VAL_32:.*]] = arith.index_castui %[[VAL_29]] : index to i32
// CHECK:             %[[VAL_33:.*]] = arith.uitofp %[[VAL_32]] : i32 to f32
// CHECK:             %[[VAL_34:.*]] = arith.divf %[[VAL_31]], %[[VAL_14]] : f32
// CHECK:             %[[VAL_35:.*]] = arith.divf %[[VAL_33]], %[[VAL_16]] : f32
// CHECK:             %[[VAL_36:.*]] = arith.addf %[[VAL_34]], %[[VAL_35]] : f32
// CHECK:             %[[VAL_37:.*]] = arith.mulf %[[VAL_12]], %[[VAL_36]] : f32
// CHECK:             %[[VAL_38:.*]] = math.cos %[[VAL_37]] : f32
// CHECK:             %[[VAL_39:.*]] = math.sin %[[VAL_37]] : f32
// CHECK:             %[[VAL_40:.*]] = arith.mulf %[[VAL_18]], %[[VAL_38]] : f32
// CHECK:             %[[VAL_41:.*]] = arith.mulf %[[VAL_19]], %[[VAL_39]] : f32
// CHECK:             %[[VAL_42:.*]] = arith.addf %[[VAL_40]], %[[VAL_41]] : f32
// CHECK:             %[[VAL_43:.*]] = arith.mulf %[[VAL_19]], %[[VAL_38]] : f32
// CHECK:             %[[VAL_44:.*]] = arith.mulf %[[VAL_18]], %[[VAL_39]] : f32
// CHECK:             %[[VAL_45:.*]] = arith.subf %[[VAL_43]], %[[VAL_44]] : f32
// CHECK:             %[[VAL_46:.*]] = arith.addf %[[VAL_20]], %[[VAL_42]] : f32
// CHECK:             %[[VAL_47:.*]] = arith.addf %[[VAL_21]], %[[VAL_45]] : f32
// CHECK:             linalg.yield %[[VAL_46]], %[[VAL_47]] : f32, f32
// CHECK:           } -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>)
// CHECK:           return %[[VAL_48:.*]]#0, %[[VAL_48]]#1 : tensor<8x8x8xf32>, tensor<8x8x8xf32>
// CHECK:         }
func.func @test_static_fft2d(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse=false} : (tensor<8x8x8xf32>, tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>)
  return %output_real, %output_imag : tensor<8x8x8xf32>, tensor<8x8x8xf32>
}
1900
1901// -----
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

// Inverse FFT (inverse = true) on fully dynamic inputs: all three result
// extents are recovered with tensor.dim and passed to tensor.empty, and the
// angle is additionally multiplied by an arith.constant -1.0 before cos/sin,
// which is the only structural difference from the forward (inverse = false)
// lowering above.
// CHECK-LABEL:   func.func @test_dynamic_fft2d(
// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
// CHECK:           %[[VAL_8:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) : tensor<?x?x?xf32>
// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_8]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK:           %[[VAL_11:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) : tensor<?x?x?xf32>
// CHECK:           %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_11]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK:           %[[VAL_14:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_15:.*]] = tensor.dim %[[VAL_0]], %[[VAL_14]] : tensor<?x?x?xf32>
// CHECK:           %[[VAL_16:.*]] = arith.constant 2 : index
// CHECK:           %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?x?xf32>
// CHECK:           %[[VAL_18:.*]] = arith.constant 6.28318548 : f32
// CHECK:           %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
// CHECK:           %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
// CHECK:           %[[VAL_21:.*]] = arith.index_castui %[[VAL_17]] : index to i32
// CHECK:           %[[VAL_22:.*]] = arith.uitofp %[[VAL_21]] : i32 to f32
// CHECK:           %[[VAL_23:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_10]], %[[VAL_13]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
// CHECK:           ^bb0(%[[VAL_24:.*]]: f32, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
// CHECK:             %[[VAL_28:.*]] = linalg.index 1 : index
// CHECK:             %[[VAL_29:.*]] = linalg.index 2 : index
// CHECK:             %[[VAL_30:.*]] = linalg.index 3 : index
// CHECK:             %[[VAL_31:.*]] = linalg.index 4 : index
// CHECK:             %[[VAL_32:.*]] = index.mul %[[VAL_30]], %[[VAL_28]]
// CHECK:             %[[VAL_33:.*]] = index.mul %[[VAL_31]], %[[VAL_29]]
// CHECK:             %[[VAL_34:.*]] = index.remu %[[VAL_32]], %[[VAL_15]]
// CHECK:             %[[VAL_35:.*]] = index.remu %[[VAL_33]], %[[VAL_17]]
// CHECK:             %[[VAL_36:.*]] = arith.index_castui %[[VAL_34]] : index to i32
// CHECK:             %[[VAL_37:.*]] = arith.uitofp %[[VAL_36]] : i32 to f32
// CHECK:             %[[VAL_38:.*]] = arith.index_castui %[[VAL_35]] : index to i32
// CHECK:             %[[VAL_39:.*]] = arith.uitofp %[[VAL_38]] : i32 to f32
// CHECK:             %[[VAL_40:.*]] = arith.divf %[[VAL_37]], %[[VAL_20]] : f32
// CHECK:             %[[VAL_41:.*]] = arith.divf %[[VAL_39]], %[[VAL_22]] : f32
// CHECK:             %[[VAL_42:.*]] = arith.addf %[[VAL_40]], %[[VAL_41]] : f32
// CHECK:             %[[VAL_43:.*]] = arith.mulf %[[VAL_18]], %[[VAL_42]] : f32
// CHECK:             %[[VAL_44:.*]] = arith.constant -1.000000e+00 : f32
// CHECK:             %[[VAL_45:.*]] = arith.mulf %[[VAL_43]], %[[VAL_44]] : f32
// CHECK:             %[[VAL_46:.*]] = math.cos %[[VAL_45]] : f32
// CHECK:             %[[VAL_47:.*]] = math.sin %[[VAL_45]] : f32
// CHECK:             %[[VAL_48:.*]] = arith.mulf %[[VAL_24]], %[[VAL_46]] : f32
// CHECK:             %[[VAL_49:.*]] = arith.mulf %[[VAL_25]], %[[VAL_47]] : f32
// CHECK:             %[[VAL_50:.*]] = arith.addf %[[VAL_48]], %[[VAL_49]] : f32
// CHECK:             %[[VAL_51:.*]] = arith.mulf %[[VAL_25]], %[[VAL_46]] : f32
// CHECK:             %[[VAL_52:.*]] = arith.mulf %[[VAL_24]], %[[VAL_47]] : f32
// CHECK:             %[[VAL_53:.*]] = arith.subf %[[VAL_51]], %[[VAL_52]] : f32
// CHECK:             %[[VAL_54:.*]] = arith.addf %[[VAL_26]], %[[VAL_50]] : f32
// CHECK:             %[[VAL_55:.*]] = arith.addf %[[VAL_27]], %[[VAL_53]] : f32
// CHECK:             linalg.yield %[[VAL_54]], %[[VAL_55]] : f32, f32
// CHECK:           } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK:           return %[[VAL_56:.*]]#0, %[[VAL_56]]#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
// CHECK:         }
func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
1968
1969
1970// -----
1971
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>

// Saturating f32 -> i64 cast: round-to-nearest-even, clamp the low side with
// maximumf against the f32 representation of INT64_MIN before fptosi, then
// use cmpf uge / select against the f32 value just past INT64_MAX to pin
// overflowing (and NaN, via the unordered 'uge') inputs to INT64_MAX instead
// of relying on fptosi's out-of-range behavior.
// CHECK-LABEL:   func.func @test_cast_fp32_i64(
// CHECK-SAME:                                  %[[ARG0:.*]]: tensor<1xf32>) -> tensor<1xi64> {
// CHECK:           %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<1xi64>
// CHECK:           %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<1xi64>) {
// CHECK:           ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: i64):
// CHECK:             %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
// CHECK:             %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
// CHECK:             %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32
// CHECK:             %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64
// CHECK:             %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32
// CHECK:             %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64
// CHECK:             %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32
// CHECK:             %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64
// CHECK:             linalg.yield %[[SELECT]] : i64
// CHECK:           } -> tensor<1xi64>
// CHECK:           return %[[RESULT]] : tensor<1xi64>
func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
  %0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
  return %0: tensor<1xi64>
}
1994