// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s

// Verifies that different argument types are legal.
func.func @generalize_matmul_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f16f64f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
// CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
// CHECK-NEXT:   %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
// CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----

// Verifies that different argument types are legal.
func.func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i16, %[[B_ARG:.+]]: i64, %[[C_ARG:.+]]: i32)
// Verify signed integer extension and truncation.
// CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i16 to i32
// CHECK-NEXT:   %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i64 to i32
// CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----

// Verifies that the cast attribute controls the cast operations used.
func.func @generalize_matmul_tensor_i16i64i32_unsigned(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
                     ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32_unsigned
// CHECK:        = arith.extui

// -----

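// Verifies that integer inputs with a floating point output are legal.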
func.func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i16i64f32
// Verify signed integer to floating point cast.
// CHECK:        = arith.sitofp
// CHECK:        = arith.sitofp

// -----

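// Verifies that floating point inputs with an integer output are legal.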
func.func @generalize_matmul_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f16f64i32
// Verify floating point to signed integer cast.
// CHECK:        = arith.fptosi
// CHECK:        = arith.fptosi

// -----

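// Verifies that the cast_unsigned attribute selects unsigned extension and truncation.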
func.func @generalize_matmul_unsigned_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul { cast = #linalg.type_fn<cast_unsigned> }
                       ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
                       outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64i32
// Verify unsigned integer extension and truncation.
// CHECK:        = arith.extui
// CHECK:        = arith.trunci

// -----

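// Verifies that the cast_unsigned attribute selects the unsigned integer to floating point cast.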
func.func @generalize_matmul_unsigned_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul { cast = #linalg.type_fn<cast_unsigned> }
                       ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
                       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64f32
// Verify unsigned integer to floating point cast.
// CHECK:        = arith.uitofp
// CHECK:        = arith.uitofp

// -----

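// Verifies that the cast_unsigned attribute selects the floating point to unsigned integer cast.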
func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul { cast = #linalg.type_fn<cast_unsigned> }
                       ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
                       outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_unsigned_tensor_f16f64i32
// Verify floating point to unsigned integer cast.
// CHECK:        = arith.fptoui
// CHECK:        = arith.fptoui

// -----

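// Verifies that linalg.contract with matmul indexing maps generalizes like linalg.matmul.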
func.func @generalize_matmul_as_contraction_tensor_f16f64f32(
    %A: tensor<16x8xf16>,
    %B: tensor<8x32xf64>,
    %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.contract
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>]
      ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
      outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_as_contraction_tensor_f16f64f32
// CHECK:         ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
// CHECK-NEXT:      %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
// CHECK-NEXT:      %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
// CHECK-NEXT:    -> tensor<16x32xf32>

// -----

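// Verifies that f16 inputs are extended inside the body while the truncation of the
// result stays a separate op outside the generalized contraction.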
func.func @generalize_matmul_as_contract_with_ext_and_trunc(
    %A: tensor<24x12xf16>,
    %B: tensor<12x25xf16>,
    %C: tensor<24x25xf32>) -> tensor<24x25xf16> {
  %0 = linalg.contract
      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                       affine_map<(m, n, k) -> (k, n)>,
                       affine_map<(m, n, k) -> (m, n)>]
      ins(%A, %B : tensor<24x12xf16>, tensor<12x25xf16>)
      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
  %1 = arith.truncf %0 : tensor<24x25xf32> to tensor<24x25xf16>
  func.return %1 : tensor<24x25xf16>
}

// CHECK-LABEL: @generalize_matmul_as_contract_with_ext_and_trunc
// CHECK:         ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
// CHECK-NEXT:      %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
// CHECK-NEXT:      %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
// CHECK-NEXT:    -> tensor<24x25xf32>
// CHECK-NEXT:    %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>

// -----

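// Verifies that max pooling reduces with arith.maximumf for floating point.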
func.func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[MAX:.+]] = arith.maximumf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

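// Verifies the 1-D (NWC) variant of floating point max pooling.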
func.func @generalize_pooling_nwc_max_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> {
  %0 = linalg.pooling_nwc_max {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
    ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32>
  return %0: tensor<1x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nwc_max_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[MAX:.+]] = arith.maximumf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x4x1xf32>

// -----

func.func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_i32
// Verify signed integer maximum.
// CHECK:        = arith.maxsi

// -----

func.func @generalize_pooling_nwc_max_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_max {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
    ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_max_i32
// Verify signed integer maximum.
// CHECK:        = arith.maxsi

// -----

func.func @generalize_pooling_nhwc_max_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_unsigned_i32
// Verify unsigned integer maximum.
// CHECK:        = arith.maxui

// -----

func.func @generalize_pooling_nwc_max_unsigned_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_max_unsigned {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
    ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_max_unsigned_i32
// Verify unsigned integer maximum.
// CHECK:        = arith.maxui

// -----

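// Verifies that min pooling reduces with arith.minimumf for floating point.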
func.func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[MIN:.+]] = arith.minimumf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func.func @generalize_pooling_nwc_min_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> {
  %0 = linalg.pooling_nwc_min {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
    ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32>
  return %0: tensor<1x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nwc_min_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[MIN:.+]] = arith.minimumf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x4x1xf32>

// -----

func.func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_i32
// Verify signed integer minimum.
// CHECK:        = arith.minsi

// -----

func.func @generalize_pooling_nwc_min_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_min {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
    ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_min_i32
// Verify signed integer minimum.
// CHECK:        = arith.minsi

// -----

func.func @generalize_pooling_nhwc_min_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_unsigned_i32
// Verify unsigned integer minimum.
// CHECK:        = arith.minui

// -----

func.func @generalize_pooling_nwc_min_unsigned_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_min_unsigned {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
    ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_min_unsigned_i32
// Verify unsigned integer minimum.
// CHECK:        = arith.minui

// -----

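// Verifies that sum pooling reduces with arith.addf for floating point.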
func.func @generalize_pooling_nhwc_sum_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_sum_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func.func @generalize_pooling_nwc_sum_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> {
  %0 = linalg.pooling_nwc_sum {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
    ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32>
  return %0: tensor<1x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nwc_sum_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<1x4x1xf32>

// -----

func.func @generalize_pooling_nhwc_sum_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_sum_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> {
  %0 = linalg.pooling_nwc_sum {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>}
    ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32>
  return %0: tensor<1x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nwc_sum_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<1x4x1xi32>

// -----

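// Verifies that a rank-zero fill generalizes with empty indexing maps and iterator types.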
func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
  %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
  return %0: tensor<f32>
}

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>

// CHECK-LABEL: @generalize_fill_0d
// CHECK:      linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK-SAME: iterator_types = []

// -----

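// Verifies that a 2-D fill generalizes with a scalar input map and parallel iterators.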
func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
  linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
  return
}

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @generalize_fill_2d
// CHECK:      linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]

// -----

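// Verifies that fill_rng_2d computes its output indices with linalg.index.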
func.func @generalize_index(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_index
// CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG:    %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
// CHECK-DAG:    %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32

// -----

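// Verifies that fill_rng_2d materializes its RNG constants.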
func.func @generalize_const(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_const
// CHECK-DAG:    %[[CST0:.+]] = arith.constant 1103515245 : i32
// CHECK-DAG:    %[[CST1:.+]] = arith.constant 12345 : i32
// CHECK-DAG:    %[[CST2:.+]] = arith.constant 2.3283063999999999E-10 : f64

// -----

// Verifies the default value of the fun attribute is an exp op.
func.func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_exp
// CHECK:        = math.exp

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<log>}
                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_log
// CHECK:        = math.log

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_abs
// CHECK:        = math.absf

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_ceil(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<ceil>}
                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_ceil
// CHECK:        = math.ceil

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_floor(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<floor>}
                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_floor
// CHECK:        = math.floor

// -----

// Verifies the fun attribute controls the unary function used.
func.func @generalize_elemwise_negf(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_negf
// CHECK:        = arith.negf

// -----

// Verifies the default value of the fun attribute is an add op.
func.func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_add
// CHECK:        = arith.addf

// -----

// Verifies the fun attribute controls the binary function used.
func.func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
                              ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_mul
// CHECK:        = arith.mulf

// -----

// Verifies pointwise ops support rank-zero input tensors.
func.func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
                              ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_rank_zero
// CHECK:       linalg.generic
// CHECK-SAME:  iterator_types = ["parallel", "parallel"]
// CHECK:        = arith.subf

// -----

// Verifies that linalg.copy generalizes to a generic that yields its input unchanged.
func.func @generalize_copy(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.copy ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_copy
//       CHECK:   linalg.generic
//  CHECK-NEXT:   ^bb0(%[[I:[0-9a-zA-Z]*]]: f32
//  CHECK-NEXT:   linalg.yield %[[I]]
558