xref: /llvm-project/mlir/test/Dialect/Linalg/generalize-named-ops.mlir (revision 0d4efa27252cbbea4b5672d4d8ffc15a3ba51d83)
1// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
2
// matmul on static memrefs -> 3-loop generic (two parallel, one reduction) with mulf+addf body.
3func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
4  linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>)
5               outs(%C: memref<16x32xf32>)
6  return
7}
8
9
10// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
11// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
12// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
13
14// CHECK: func @generalize_matmul_buffer
15// CHECK-SAME: %[[A:.+]]: memref<16x8xf32>
16// CHECK-SAME: %[[B:.+]]: memref<8x32xf32>
17// CHECK-SAME: %[[C:.+]]: memref<16x32xf32>
18
19// CHECK: linalg.generic
20// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
21// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
22// CHECK-SAME:  ins(%[[A]], %[[B]]
23// CHECK-SAME: outs(%[[C]]
24
25// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
26// CHECK:   %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
27// CHECK:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
28// CHECK:   linalg.yield %[[ADD]] : f32
29
30// -----
31
// matmul with user-supplied indexing_maps that broadcast operand A to a single (k) dimension.
32func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
33  linalg.matmul indexing_maps = [
34                       affine_map<(d0, d1, d2) -> (d2)>,
35                       affine_map<(d0, d1, d2) -> (d2, d1)>,
36                       affine_map<(d0, d1, d2) -> (d0, d1)>
37                     ]
38                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
39  return
40}
41
42// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
43// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
44// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
45// CHECK-LABEL:   func.func @matmul_bcast_a(
46// CHECK-SAME:                              %[[VAL_0:.*]]: memref<5xf32>,
47// CHECK-SAME:                              %[[VAL_1:.*]]: memref<5x7xf32>,
48// CHECK-SAME:                              %[[VAL_2:.*]]: memref<3x7xf32>) {
49// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
50// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
51// CHECK:             %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
52// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
53// CHECK:             linalg.yield %[[VAL_7]] : f32
54// CHECK:           }
55// CHECK:           return
56// CHECK:         }
57
58// -----
59
// matmul on tensors: same mulf+addf body, but the generic yields a tensor result.
60func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
61  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
62                    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
63  return %0: tensor<16x32xf32>
64}
65
66// CHECK: func @generalize_matmul_tensor
67
68// CHECK: linalg.generic
69// CHECK-SAME:  ins(%{{.+}}, %{{.+}} : tensor<16x8xf32>, tensor<8x32xf32>)
70// CHECK-SAME: outs(%{{.+}} : tensor<16x32xf32>)
71
72// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
73// CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
74// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
75// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
76// CHECK-NEXT: -> tensor<16x32xf32>
77
78// -----
79
// matmul on complex elements: the generalized body uses complex.mul / complex.add.
80func.func @generalize_matmul_tensor_complex(%A : tensor<16x8xcomplex<f32>>,
81                                            %B: tensor<8x32xcomplex<f32>>,
82                                            %C: tensor<16x32xcomplex<f32>>)
83          -> tensor<16x32xcomplex<f32>> {
84  %0 = linalg.matmul ins(%A, %B: tensor<16x8xcomplex<f32>>, tensor<8x32xcomplex<f32>>)
85                    outs(%C: tensor<16x32xcomplex<f32>>) -> tensor<16x32xcomplex<f32>>
86  return %0: tensor<16x32xcomplex<f32>>
87}
88
89// CHECK: func @generalize_matmul_tensor_complex
90
91// CHECK: linalg.generic
92// CHECK-SAME:  ins(%{{.+}}, %{{.+}} : tensor<16x8xcomplex<f32>>, tensor<8x32xcomplex<f32>>)
93// CHECK-SAME: outs(%{{.+}} : tensor<16x32xcomplex<f32>>)
94
95// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: complex<f32>, %[[B_ARG:.+]]: complex<f32>, %[[C_ARG:.+]]: complex<f32>)
96// CHECK-NEXT:   %[[MUL:.+]] = complex.mul %[[A_ARG]], %[[B_ARG]] : complex<f32>
97// CHECK-NEXT:   %[[ADD:.+]] = complex.add %[[C_ARG]], %[[MUL]] : complex<f32>
98// CHECK-NEXT:   linalg.yield %[[ADD]] : complex<f32>
99// CHECK-NEXT: -> tensor<16x32xcomplex<f32>>
100
101// -----
102
// Depthwise conv with channel multiplier, unit strides and dilations: input map is d1+d5, d2+d6.
103func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
104  linalg.depthwise_conv_2d_nhwc_hwcm
105     { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
106     ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
107    outs(%output : memref<2x3x4x2x3xf32>)
108  return
109}
110
111// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
112// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
113// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
114
115// CHECK: func @depthwise_conv_2d_nhwc_hwcm
116
117// CHECK: linalg.generic
118// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
119// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
120// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
121// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)
122
123// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
124// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
125// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
126// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
127
128// -----
129
// Same depthwise conv but with dilation 2: the filter dims d5/d6 are scaled by 2 in the input map.
130func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x2x3x2x3xf32>) {
131  linalg.depthwise_conv_2d_nhwc_hwcm
132     { dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
133     ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
134    outs(%output : memref<2x2x3x2x3xf32>)
135  return
136}
137
138// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5 * 2, d2 + d6 * 2, d3)>
139// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
140// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
141
142// CHECK: func @depthwise_conv_2d_nhwc_hwcm
143
144// CHECK: linalg.generic
145// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
146// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
147// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
148// CHECK-SAME: outs(%{{.+}} : memref<2x2x3x2x3xf32>)
149
150// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
151// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
152// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
153// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
154
155// -----
156
// Depthwise conv without a channel multiplier, stride 2: spatial dims appear as d1*2+d4, d2*2+d5.
157func.func @depthwise_conv_2d_nhwc_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
158  linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
159    ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
160    outs(%output: memref<1x56x56x96xf32>)
161  return
162}
163
164// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
165// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
166// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
167
168// CHECK: func @depthwise_conv_2d_nhwc_hwc
169
170// CHECK: linalg.generic
171// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
172// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
173// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
174// CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>)
175
176// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
177// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
178// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
179// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
180
181// -----
182
// 1-d NWC/WCF conv on fully dynamic memrefs.
183func.func @conv_1d_nwc_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
184  linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
185                                       strides = dense<1> : tensor<1xi64>}
186     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
187    outs (%output: memref<?x?x?xf32>)
188  return
189}
190// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
191// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
192// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
193
194// CHECK: func @conv_1d_nwc_wcf
195
196// CHECK: linalg.generic
197// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
198// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
199// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
200// CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
201
202// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
203// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
204// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
205// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
206
207// -----
208
// 1-d NCW/FCW conv on fully dynamic memrefs.
209func.func @conv_1d_ncw_fcw(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
210  linalg.conv_1d_ncw_fcw {dilations = dense<1> : tensor<1xi64>,
211                                       strides = dense<1> : tensor<1xi64>}
212     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
213    outs (%output: memref<?x?x?xf32>)
214  return
215}
216// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
217// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
218// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
219
220// CHECK: func @conv_1d_ncw_fcw
221
222// CHECK: linalg.generic
223// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
224// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
225// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
226// CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
227
228// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
229// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
230// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
231// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
232
233// -----
234
// Quantized grouped 2-d conv: the two zero points become scalar inputs (rank-0 map) and
// are subtracted from the sign-extended i8 operands before the multiply-accumulate.
235func.func @conv_2d_ngchw_gfchw_q(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xi32>) {
236  linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>,
237                                       strides = dense<1> : tensor<2xi64>}
238     ins (%input, %filter, %inputzp, %filterzp: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
239    outs (%output: memref<?x?x?x?x?xi32>)
240  return
241}
242// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
243// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
244// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
245// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
246
247// CHECK: func @conv_2d_ngchw_gfchw_q
248
249// CHECK: linalg.generic
250// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP2]], #[[MAP3]]]
251// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
252// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
253// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xi32>)
254
255// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i32, %[[BBARG3:.+]]: i32, %[[BBARG4:.+]]: i32)
256// CHECK-NEXT:      %[[EXTSI0:.+]] = arith.extsi %[[BBARG0]] : i8 to i32
257// CHECK-NEXT:      %[[SUB0:.+]] = arith.subi %[[EXTSI0]], %[[BBARG2]] : i32
258// CHECK-NEXT:      %[[EXTSI1:.+]] = arith.extsi %[[BBARG1]] : i8 to i32
259// CHECK-NEXT:      %[[SUB1:.+]] = arith.subi %[[EXTSI1]], %[[BBARG3]] : i32
260// CHECK-NEXT:      %[[MUL:.+]] = arith.muli %[[SUB0]], %[[SUB1]] : i32
261// CHECK-NEXT:      %[[ADD:.+]] = arith.addi %[[BBARG4]], %[[MUL]] : i32
262// CHECK-NEXT:      linalg.yield %[[ADD]] : i32
263
264// -----
265
// fill: scalar input mapped with an empty affine map and yielded directly.
266func.func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
267  linalg.fill ins(%value : f32) outs(%output : memref<?x?xf32>)
268  return
269}
270
271// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
272// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
273
274// CHECK: func @generalize_fill
275// CHECK-SAME: (%[[ARG0:.+]]: memref<?x?xf32>, %[[VAL:.+]]: f32)
276
277// CHECK: linalg.generic
278// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
279// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
280// CHECK-SAME: ins(%[[VAL]] : f32)
281// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
282
283// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
284// CHECK-NEXT:      linalg.yield %[[BBARG0]] : f32
285
286// -----
287
// batch matvec with i8 inputs accumulated into f32: operands go through arith.sitofp.
288func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi8>,  %out: memref<?x?xf32>) {
289  linalg.batch_matvec ins(%lhs, %rhs: memref<?x?x?xi8>, memref<?x?xi8>)
290                     outs(%out: memref<?x?xf32>)
291  return
292}
293// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
294// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
295// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
296
297// CHECK: @generalize_batch_matm_vec
298
299// CHECK: linalg.generic
300// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
301// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
302// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xi8>, memref<?x?xi8>)
303// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
304// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
305// CHECK:            %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
306// CHECK:            %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
307// CHECK:            %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
308// CHECK:            %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
309// CHECK:            linalg.yield %[[ADD]] : f32
310
311// -----
312
// batch vecmat with i8 inputs accumulated into f32: operands go through arith.sitofp.
313func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>,  %out: memref<?x?xf32>) {
314  linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
315                     outs(%out: memref<?x?xf32>)
316  return
317}
318// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
319// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
320// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
321
322// CHECK: @generalize_batch_vecmat
323
324// CHECK: linalg.generic
325// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
326// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
327// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
328// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
329// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
330// CHECK:            %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
331// CHECK:            %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
332// CHECK:            %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
333// CHECK:            %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
334// CHECK:            linalg.yield %[[ADD]] : f32
335
336// -----
337
// batch_reduce_matmul: the batch dimension d0 is a reduction iterator, so the output drops it.
338func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
339  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
340                             outs(%out: memref<8x8xf32>)
341  return
342}
343
344// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
345// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
346// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
347
348// CHECK: @batch_reduce_gemm
349
350// CHECK: linalg.generic
351// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
352// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
353// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xf32>, memref<7x9x8xf32>)
354// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
355// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
356// CHECK:         %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
357// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
358// CHECK:         linalg.yield %[[ADD]] : f32
359
360// -----
361
// batch_reduce_matmul with bf16 inputs: operands are extended to f32 (arith.extf) before mul/add.
362func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: memref<7x9x8xbf16>, %out: memref<8x8xf32>) {
363  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xbf16>, memref<7x9x8xbf16>)
364                             outs(%out: memref<8x8xf32>)
365  return
366}
367
368// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
369// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
370// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
371
372// CHECK: @generalize_batch_reduce_gemm_bf16
373
374// CHECK: linalg.generic
375// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
376// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
377// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xbf16>, memref<7x9x8xbf16>)
378// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
379// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: bf16, %[[BBARG1:.+]]: bf16, %[[BBARG2:.+]]: f32)
380// CHECK:         %[[EXTBF16_0:.+]] = arith.extf %[[BBARG0]] : bf16 to f32
381// CHECK:         %[[EXTBF16_1:.+]] = arith.extf %[[BBARG1]] : bf16 to f32
382// CHECK:         %[[MUL:.+]] = arith.mulf %[[EXTBF16_0]], %[[EXTBF16_1]] : f32
383// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
384// CHECK:         linalg.yield %[[ADD]] : f32
385
386
387// -----
388
// Negative test: linalg.map must be left untouched by the generalization pass.
389// CHECK-LABEL: generalize_linalg_map
390func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
391  %cst = arith.constant 0.000000e+00 : f32
392  // CHECK: linalg.map
393  // CHECK-NOT: linalg.generic
394  linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
395    () {
396      linalg.yield %cst : f32
397    }
398  return
399}
400
401// -----
402
// Elementwise add: identity maps on all operands, all-parallel iterators.
403func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
404                          %out: memref<7x14x21xf32>) {
405  linalg.add ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
406             outs(%out : memref<7x14x21xf32>)
407  return
408}
409
410// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
411
412// CHECK: func @generalize_add
413// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
414// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
415
416// CHECK: linalg.generic
417// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
418// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
419// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
420// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
421
422// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
423// CHECK-NEXT:      %[[SUM:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
424// CHECK-NEXT:      linalg.yield %[[SUM]] : f32
425
426// -----
427
// Elementwise sub: identity maps on all operands, all-parallel iterators.
428func.func @generalize_sub(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
429                          %out: memref<7x14x21xf32>) {
430  linalg.sub ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
431             outs(%out : memref<7x14x21xf32>)
432  return
433}
434
435// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
436
437// CHECK: func @generalize_sub
438// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
439// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
440
441// CHECK: linalg.generic
442// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
443// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
444// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
445// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
446
447// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
448// CHECK-NEXT:      %[[SUB:.+]] = arith.subf %[[BBARG0]], %[[BBARG1]] : f32
449// CHECK-NEXT:      linalg.yield %[[SUB]] : f32
450
451// -----
452
// Elementwise mul: identity maps on all operands, all-parallel iterators.
453func.func @generalize_mul(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
454                          %out: memref<7x14x21xf32>) {
455  linalg.mul ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
456             outs(%out : memref<7x14x21xf32>)
457  return
458}
459
460// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
461
462// CHECK: func @generalize_mul
463// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
464// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
465
466// CHECK: linalg.generic
467// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
468// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
469// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
470// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
471
472// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
473// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
474// CHECK-NEXT:      linalg.yield %[[MUL]] : f32
475
476// -----
477
// Elementwise float div: identity maps on all operands, all-parallel iterators.
478func.func @generalize_div(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
479                          %out: memref<7x14x21xf32>) {
480  linalg.div ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
481             outs(%out : memref<7x14x21xf32>)
482  return
483}
484
485// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
486
487// CHECK: func @generalize_div
488// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
489// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
490
491// CHECK: linalg.generic
492// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
493// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
494// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
495// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
496
497// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
498// CHECK-NEXT:      %[[DIV:.+]] = arith.divf %[[BBARG0]], %[[BBARG1]] : f32
499// CHECK-NEXT:      linalg.yield %[[DIV]] : f32
500
501// -----
502
// Elementwise unsigned integer div: body lowers to arith.divui.
503func.func @generalize_divu(%lhs: memref<7x14x21xi32>, %rhs: memref<7x14x21xi32>,
504                          %out: memref<7x14x21xi32>) {
505  linalg.div_unsigned ins(%lhs, %rhs : memref<7x14x21xi32>, memref<7x14x21xi32>)
506             outs(%out : memref<7x14x21xi32>)
507  return
508}
509
510// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
511
512// CHECK: func @generalize_divu
513// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xi32>, %[[RHS:.+]]: memref<7x14x21xi32>,
514// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xi32>)
515
516// CHECK: linalg.generic
517// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
518// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
519// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xi32>, memref<7x14x21xi32>)
520// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xi32>)
521
522// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32, %[[BBARG2:.+]]: i32)
523// CHECK-NEXT:      %[[DIVU:.+]] = arith.divui %[[BBARG0]], %[[BBARG1]] : i32
524// CHECK-NEXT:      linalg.yield %[[DIVU]] : i32
525
526// -----
527
// Elementwise exp. NOTE(review): the ins() capture is fixed to %[[ARG]]; it
// previously reused %[[LHS]], a variable bound in an earlier test case (the
// test only passed because both happened to bind %arg0).
528func.func @generalize_exp(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
529  linalg.exp ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
530  return
531}
532
533// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
534
535// CHECK: func @generalize_exp
536// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
537
538// CHECK: linalg.generic
539// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
540// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
541// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
542
543// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
544// CHECK-NEXT:      %[[EXP:.+]] = math.exp %[[BBARG0]] : f32
545// CHECK-NEXT:      linalg.yield %[[EXP]] : f32
546
547// -----
548
// Elementwise log. NOTE(review): ins() now uses %[[ARG]] (was the stale
// %[[LHS]] bound by an earlier test case).
549func.func @generalize_log(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
550  linalg.log ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
551  return
552}
553
554// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
555
556// CHECK: func @generalize_log
557// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
558
559// CHECK: linalg.generic
560// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
561// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
562// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
563
564// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
565// CHECK-NEXT:      %[[log:.+]] = math.log %[[BBARG0]] : f32
566// CHECK-NEXT:      linalg.yield %[[log]] : f32
567
568// -----
569
// Elementwise abs. NOTE(review): ins() now uses %[[ARG]] (was the stale
// %[[LHS]] bound by an earlier test case).
570func.func @generalize_abs(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
571  linalg.abs ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
572  return
573}
574
575// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
576
577// CHECK: func @generalize_abs
578// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
579
580// CHECK: linalg.generic
581// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
582// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
583// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
584
585// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
586// CHECK-NEXT:      %[[abs:.+]] = math.absf %[[BBARG0]] : f32
587// CHECK-NEXT:      linalg.yield %[[abs]] : f32
588
589// -----
590
// Elementwise ceil. NOTE(review): ins() now uses %[[ARG]] (was the stale
// %[[LHS]] bound by an earlier test case).
591func.func @generalize_ceil(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
592  linalg.ceil ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
593  return
594}
595
596// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
597
598// CHECK: func @generalize_ceil
599// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
600
601// CHECK: linalg.generic
602// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
603// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
604// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
605
606// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
607// CHECK-NEXT:      %[[ceil:.+]] = math.ceil %[[BBARG0]] : f32
608// CHECK-NEXT:      linalg.yield %[[ceil]] : f32
609
610// -----
611
// Elementwise floor. NOTE(review): ins() now uses %[[ARG]] (was the stale
// %[[LHS]] bound by an earlier test case).
612func.func @generalize_floor(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
613  linalg.floor ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
614  return
615}
616
617// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
618
619// CHECK: func @generalize_floor
620// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
621
622// CHECK: linalg.generic
623// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
624// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
625// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
626
627// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
628// CHECK-NEXT:      %[[floor:.+]] = math.floor %[[BBARG0]] : f32
629// CHECK-NEXT:      linalg.yield %[[floor]] : f32
630
631// -----
632
// Elementwise negf. NOTE(review): ins() now uses %[[ARG]] (was the stale
// %[[LHS]] bound by an earlier test case).
633func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
634  linalg.negf ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
635  return
636}
637
638// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
639
640// CHECK: func @generalize_negf
641// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
642
643// CHECK: linalg.generic
644// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
645// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
646// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
647
648// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
649// CHECK-NEXT:      %[[negf:.+]] = arith.negf %[[BBARG0]] : f32
650// CHECK-NEXT:      linalg.yield %[[negf]] : f32
651
652// -----
653
// Elementwise reciprocal: lowered as constant-1 divided by the input.
// NOTE(review): ins() now uses %[[ARG]] (was the stale %[[LHS]] bound by an
// earlier test case).
654func.func @generalize_reciprocal(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
655  linalg.reciprocal ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
656  return
657}
658
659// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
660
661// CHECK: func @generalize_reciprocal
662// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
663
664// CHECK: %[[one:.+]] = arith.constant 1.000000e+00 : f32
665
666// CHECK: linalg.generic
667// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
668// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
669// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
670
671// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
672// CHECK-NEXT:      %[[reciprocal:.+]] = arith.divf %[[one]], %[[BBARG0]] : f32
673// CHECK-NEXT:      linalg.yield %[[reciprocal]] : f32
674
675// -----
676
677func.func @generalize_round(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
678  linalg.round ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
679  return
680}
681
682// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
683
684// CHECK: func @generalize_round
685// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
686
687// CHECK: linalg.generic
688// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
689// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
691
692// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
693// CHECK-NEXT:      %[[round:.+]] = math.round %[[BBARG0]] : f32
694// CHECK-NEXT:      linalg.yield %[[round]] : f32
695
696// -----
697
// Square root of a static 3-D memref; the patterns below verify it generalizes
// to an elementwise linalg.generic whose payload is math.sqrt.
func.func @generalize_sqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.sqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}
702
703// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
704
705// CHECK: func @generalize_sqrt
706// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
707
708// CHECK: linalg.generic
709// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
710// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
712
713// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
714// CHECK-NEXT:      %[[sqrt:.+]] = math.sqrt %[[BBARG0]] : f32
715// CHECK-NEXT:      linalg.yield %[[sqrt]] : f32
716
717// -----
718
// Reciprocal square root of a static 3-D memref; the patterns below verify it
// generalizes to an elementwise linalg.generic whose payload is math.rsqrt.
func.func @generalize_rsqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.rsqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}
723
724// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
725
726// CHECK: func @generalize_rsqrt
727// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
728
729// CHECK: linalg.generic
730// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
731// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
733
734// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
735// CHECK-NEXT:      %[[rsqrt:.+]] = math.rsqrt %[[BBARG0]] : f32
736// CHECK-NEXT:      linalg.yield %[[rsqrt]] : f32
737
738// -----
739
// Squaring of a static 3-D memref; the patterns below verify it generalizes to
// an elementwise linalg.generic computing x * x via arith.mulf.
func.func @generalize_square(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.square ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}
744
745// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
746
747// CHECK: func @generalize_square
748// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
749
750// CHECK: linalg.generic
751// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
752// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
754
755// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
756// CHECK-NEXT:      %[[square:.+]] = arith.mulf %[[BBARG0]], %[[BBARG0]] : f32
757// CHECK-NEXT:      linalg.yield %[[square]] : f32
758
759// -----
760
// Hyperbolic tangent of a static 3-D memref; the patterns below verify it
// generalizes to an elementwise linalg.generic whose payload is math.tanh.
func.func @generalize_tanh(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.tanh ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}
765
766// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
767
768// CHECK: func @generalize_tanh
769// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
770
771// CHECK: linalg.generic
772// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
773// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
775
776// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
777// CHECK-NEXT:      %[[tanh:.+]] = math.tanh %[[BBARG0]] : f32
778// CHECK-NEXT:      linalg.yield %[[tanh]] : f32
779
780// -----
781
// Error function of a static 3-D memref; the patterns below verify it
// generalizes to an elementwise linalg.generic whose payload is math.erf.
func.func @generalize_erf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
  linalg.erf ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
  return
}
786
787// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
788
789// CHECK: func @generalize_erf
790// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
791
792// CHECK: linalg.generic
793// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
794// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
796
797// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
798// CHECK-NEXT:      %[[erf:.+]] = math.erf %[[BBARG0]] : f32
799// CHECK-NEXT:      linalg.yield %[[erf]] : f32
800
801// -----
802
// Binary elementwise maximum; the patterns below verify it generalizes to a
// linalg.generic whose payload is arith.maximumf on the two input arguments.
func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}
809
810// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
811
812// CHECK: func @generalize_max
813// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
814// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
815
816// CHECK: linalg.generic
817// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
818// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
819// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
820// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
821
822// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
823// CHECK-NEXT:      %[[max:.+]] = arith.maximumf %[[BBARG0]], %[[BBARG1]] : f32
824// CHECK-NEXT:      linalg.yield %[[max]] : f32
825
826// -----
827
// Binary elementwise minimum; the patterns below verify it generalizes to a
// linalg.generic whose payload is arith.minimumf on the two input arguments.
func.func @generalize_min(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.min ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}
834
835// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
836
837// CHECK: func @generalize_min
838// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
839// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
840
841// CHECK: linalg.generic
842// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
843// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
844// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
845// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
846
847// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
848// CHECK-NEXT:      %[[min:.+]] = arith.minimumf %[[BBARG0]], %[[BBARG1]] : f32
849// CHECK-NEXT:      linalg.yield %[[min]] : f32
850
851
852// -----
853
// Binary elementwise power; the patterns below verify it generalizes to a
// linalg.generic whose payload is math.powf on the two input arguments.
func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.powf ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
             outs(%out : memref<7x14x21xf32>)
  return
}
860
861// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
862
863// CHECK: func @generalize_powf
864// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
865// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
866
867// CHECK: linalg.generic
868// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
869// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
870// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
871// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
872
873// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
874// CHECK-NEXT:      %[[powf:.+]] = math.powf %[[BBARG0]], %[[BBARG1]] : f32
875// CHECK-NEXT:      linalg.yield %[[powf]] : f32
876
877
878// -----
879
// Ternary elementwise select (i1 condition plus two f32 operands); the
// patterns below verify it generalizes to a linalg.generic using arith.select.
func.func @generalize_select(%cond: memref<7x14x21xi1>, %lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                              %out: memref<7x14x21xf32>) {
  linalg.select ins(%cond, %lhs, %rhs: memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
                outs(%out: memref<7x14x21xf32>)
  return
}
886
887// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
888
889// CHECK: func @generalize_select
890// CHECK-SAME: (%[[COND:.+]]: memref<7x14x21xi1>, %[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
891// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
892
893// CHECK: linalg.generic
894// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]], #[[MAP]]]
895// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
896// CHECK-SAME:  ins(%[[COND]], %[[LHS]], %[[RHS]] : memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
897// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
898
899// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i1, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32, %[[BBARG3:.+]]: f32)
900// CHECK-NEXT:      %[[select:.+]] = arith.select %[[BBARG0]], %[[BBARG1]], %[[BBARG2]] : f32
901// CHECK-NEXT:      linalg.yield %[[select]] : f32
902
903
904// -----
905
906// CHECK-LABEL: func @fill_tensor
// linalg.fill on tensors, with both a scalar (f32) and a non-scalar
// (vector<2x4xf32>) fill value; each fill generalizes to a linalg.generic
// that simply yields the fill operand.
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
  %e0 = tensor.empty() : tensor<f32>
  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
// CHECK: linalg.generic
// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT:      linalg.yield %[[BBARG0]] : f32

  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
// CHECK: linalg.generic
// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: vector<2x4xf32>, %[[BBARG1:.+]]: vector<2x4xf32>)
// CHECK-NEXT:      linalg.yield %[[BBARG0]] : vector<2x4xf32>

  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
922
923// -----
924
925// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
926// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
927// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
928
929// CHECK-LABEL:   func.func @matmul_transpose_a_explicit(
930// CHECK-SAME:                                  %[[VAL_0:.*]]: memref<5x3xf32>,
931// CHECK-SAME:                                  %[[VAL_1:.*]]: memref<5x7xf32>,
932// CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
933
934// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
935// CHECK:           arith.mulf
936// CHECK:           arith.addf
937
// Matmul with explicit indexing maps that transpose A (LHS accessed as
// (d2, d0)); the patterns above verify the maps carry over to the generic.
func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                      ]
                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                      outs(%arg2: memref<3x7xf32>)
  return
}
948
949// -----
950
951// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
952// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
953// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
954// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
955// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
956// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
957// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
958
959// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
960// CHECK:           arith.mulf
961// CHECK:           arith.addf
962
// Matmul with explicit indexing maps that transpose B (RHS accessed as
// (d1, d2)); the patterns above verify the maps carry over to the generic.
func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                      ]
                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                      outs(%arg2: memref<3x7xf32>)
  return
}
973
974// -----
975
976// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
977// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
978// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
979
980// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
981// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
982// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
983// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
984
985// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
986// CHECK:           arith.mulf
987// CHECK:           arith.addf
988
// Matmul with explicit indexing maps transposing both A and B; the patterns
// above verify both transposed maps carry over to the generic.
func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                      ]
                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                      outs(%arg2: memref<3x7xf32>)
  return
}
999
1000// -----
1001
1002// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
1003// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
1004// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1005
1006// CHECK-LABEL: func.func @contract_matmul(
1007// CHECK-SAME:      %[[A:.*]]: memref<3x5xf32>,
1008// CHECK-SAME:      %[[B:.*]]: memref<5x7xf32>,
1009// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
1010
1011// CHECK:         linalg.generic
1012// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1013// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
1014// CHECK-NEXT:    ^{{.+}}(
1015// CHECK-NEXT:      arith.mulf
1016// CHECK-NEXT:      arith.addf
1017// CHECK-NEXT:      linalg.yield
1018
// linalg.contract expressing a plain matmul; the patterns above verify it
// generalizes to a mulf/addf generic with the same access maps.
func.func @contract_matmul(%A: memref<3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>]
      ins(%A, %B : memref<3x5xf32>, memref<5x7xf32>)
      outs(%C: memref<3x7xf32>)

  return
}
1029
1030// -----
1031
1032// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
1033// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
1034// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1035
1036// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
1037// CHECK-SAME:      %[[A:.*]]: memref<5x3xf32>,
1038// CHECK-SAME:      %[[B:.*]]: memref<7x5xf32>,
1039// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
1040
1041// CHECK:         linalg.generic
1042// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1043// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
1044// CHECK-NEXT:    ^{{.+}}(
1045// CHECK-NEXT:      arith.mulf
1046// CHECK-NEXT:      arith.addf
1047// CHECK-NEXT:      linalg.yield
1048
// linalg.contract expressing a matmul with both operands transposed; the
// patterns above verify the transposed maps carry over to the generic.
func.func @contract_matmul_transpose_a_b(%A: memref<5x3xf32>, %B: memref<7x5xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>]
      ins(%A, %B : memref<5x3xf32>, memref<7x5xf32>)
      outs(%C: memref<3x7xf32>)
  return
}
1058
1059// -----
1060
1061// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1062// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
1063// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1064
1065// CHECK-LABEL: func.func @contract_batch_matmul(
1066// CHECK-SAME:      %[[A:.*]]: memref<9x3x5xf32>,
1067// CHECK-SAME:      %[[B:.*]]: memref<9x5x7xf32>,
1068// CHECK-SAME:      %[[C:.*]]: memref<9x3x7xf32>) {
1069
1070// CHECK:         linalg.generic
1071// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1072// CHECK-SAME:        iterator_types = ["parallel", "parallel", "parallel", "reduction"]
1073// CHECK-NEXT:    ^{{.+}}(
1074// CHECK-NEXT:      arith.mulf
1075// CHECK-NEXT:      arith.addf
1076// CHECK-NEXT:      linalg.yield
1077
// linalg.contract expressing a batch matmul (leading parallel batch
// dimension d0); the patterns above verify the 4-D iteration space.
func.func @contract_batch_matmul(%A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<9x3x7xf32>) {
  linalg.contract
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
      ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
      outs(%C: memref<9x3x7xf32>)
  return
}
1087
1088// -----
1089
1090// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1091// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
1092// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1093
1094// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
1095// CHECK-SAME:      %[[A:.*]]: memref<9x3x5xf32>,
1096// CHECK-SAME:      %[[B:.*]]: memref<9x5x7xf32>,
1097// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
1098
1099// CHECK:         linalg.generic
1100// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1101// CHECK-SAME:        iterator_types = ["reduction", "parallel", "parallel", "reduction"]
1102// CHECK-NEXT:    ^{{.+}}(
1103// CHECK-NEXT:      arith.mulf
1104// CHECK-NEXT:      arith.addf
1105// CHECK-NEXT:      linalg.yield
1106
// Batch-reduce matmul via linalg.contract: the batch dimension d0 is absent
// from the output map, so the patterns above expect it to be a reduction.
#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
func.func @contract_batch_reduce_matmul(
    %A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [#accessA, #accessB, #accessC]
      ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
      outs(%C: memref<3x7xf32>)
  return
}
1118
1119// -----
1120
1121// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
1122// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
1123// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1124
1125// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
1126// CHECK-SAME:      %[[A:.*]]: memref<9x5x3xf32>,
1127// CHECK-SAME:      %[[B:.*]]: memref<9x7x5xf32>,
1128// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
1129
1130// CHECK:         linalg.generic
1131// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1132// CHECK-SAME:        iterator_types = ["reduction", "parallel", "parallel", "reduction"]
1133// CHECK-NEXT:    ^{{.+}}(
1134// CHECK-NEXT:      arith.mulf
1135// CHECK-NEXT:      arith.addf
1136// CHECK-NEXT:      linalg.yield
1137
// Batch-reduce matmul with permuted operand layouts (A stores k before m,
// B stores n before k); the patterns above verify the permuted maps survive.
#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
    %A: memref<9x5x3xf32>, %B: memref<9x7x5xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [#accessA, #accessB, #accessC]
      ins(%A, %B : memref<9x5x3xf32>, memref<9x7x5xf32>)
      outs(%C: memref<3x7xf32>)
  return
}
1149
1150// -----
1151
1152// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0) -> (d0)>
1153// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0) -> ()>
1154
1155// CHECK-LABEL: func.func @contract_dot(
1156// CHECK-SAME:      %[[A:.*]]: memref<9xf32>, %[[B:.*]]: memref<9xf32>,
1157// CHECK-SAME:      %[[C:.*]]: memref<f32>) {
1158
1159// CHECK:         linalg.generic
1160// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
1161// CHECK-SAME:        iterator_types = ["reduction"]
1162// CHECK-NEXT:    ^{{.+}}(
1163// CHECK-NEXT:      arith.mulf
1164// CHECK-NEXT:      arith.addf
1165// CHECK-NEXT:      linalg.yield
1166
// Dot product via linalg.contract: one iteration dimension, absent from the
// scalar output map; the patterns above expect a single "reduction" iterator.
#accessAB = affine_map<(d0) -> (d0)>
#accessC = affine_map<(d0) -> ()>
func.func @contract_dot(
    %A: memref<9xf32>, %B: memref<9xf32>, %C: memref<f32>) {
  linalg.contract
      indexing_maps = [#accessAB, #accessAB, #accessC]
      ins(%A, %B : memref<9xf32>, memref<9xf32>)
      outs(%C: memref<f32>)
  return
}
1177
1178// -----
1179
1180// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
1181// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1182
1183// CHECK-LABEL: func.func @contract_matmul_bcast_a_b(
1184// CHECK-SAME:      %[[A:.*]]: memref<5xf32>, %[[B:.*]]: memref<5xf32>,
1185// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
1186
1187// CHECK:         linalg.generic
1188// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
1189// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
1190// CHECK-NEXT:    ^{{.+}}(
1191// CHECK-NEXT:      arith.mulf
1192// CHECK-NEXT:      arith.addf
1193// CHECK-NEXT:      linalg.yield
1194
// Matmul-shaped contract where both 1-D inputs are broadcast along the
// parallel dims (their maps only use the reduction dim d2); the patterns
// above verify the broadcast maps carry over to the generic.
#accessAB = affine_map<(d0, d1, d2) -> (d2)>
#accessC = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @contract_matmul_bcast_a_b(
    %A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
  linalg.contract
      indexing_maps = [#accessAB, #accessAB, #accessC]
      ins(%A, %B : memref<5xf32>, memref<5xf32>)
      outs(%C: memref<3x7xf32>)
  return
}
1205