// Source: llvm-project/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (revision 47bc565ca7990a2de20af4030baf08ac62739aca)
// RUN: mlir-opt -transform-interpreter -split-input-file -canonicalize -cse %s | FileCheck %s

// Lower a dynamically shaped linalg.matmul to loops with
// transform.structured.convert_to_loops. The expected result is an (M, N, K)
// scf.for nest whose innermost body loads A[i, k], B[k, j] and C[i, j],
// computes C[i, j] + A[i, k] * B[k, j], and stores the sum back to C[i, j].
func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
  %arg2 : memref<?x?xf32>) {
  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
      outs(%arg2 : memref<?x?xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %0 = transform.structured.convert_to_loops %matmul
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func @gemm
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
//   CHECK-DAG:         %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
//   CHECK-DAG:         %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
//   CHECK-DAG:         %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
//       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
//       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
//       CHECK:         memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
// The matmul itself must be erased after lowering. Match on the op's `ins(`
// syntax rather than its bare name: the transform script, which names
// "linalg.matmul" in its structured.match, may still be printed afterwards.
//   CHECK-NOT:   linalg.matmul ins(

// -----

// Lower every TilingInterface op in the function (a matmul followed by a
// matvec) with a single convert_to_loops. split_handle expects the result
// handle to hold exactly five payload ops, matching the 3 + 2 loops checked
// below. Each original op must be erased once its loop nest is emitted.
func.func @gemm_and_matvec(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
  %arg2 : memref<?x?xf32>, %arg3 : memref<?xf32>, %arg4 : memref<?xf32>) {
  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
      outs(%arg2 : memref<?x?xf32>)
  linalg.matvec ins(%arg0, %arg3 : memref<?x?xf32>, memref<?xf32>)
      outs(%arg4 : memref<?xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %linalg_ops = transform.structured.match interface{TilingInterface} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %0 = transform.structured.convert_to_loops %linalg_ops
      : (!transform.any_op) -> (!transform.any_op)
    %1:5 = transform.split_handle %0
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func @gemm_and_matvec
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: memref<?xf32>
//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: memref<?xf32>
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
//   CHECK-DAG:         %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
//   CHECK-DAG:         %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
//   CHECK-DAG:         %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
//       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
//       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
//       CHECK:         memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
// An un-erased matmul would be printed right after its own loop nest.
//   CHECK-NOT:   linalg.matmul ins(
//       CHECK:   scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
//       CHECK:     scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
//   CHECK-DAG:         %[[MV_LHS:.+]] = memref.load %[[ARG0]][%[[IV3]], %[[IV4]]]
//   CHECK-DAG:         %[[MV_RHS:.+]] = memref.load %[[ARG3]][%[[IV4]]]
//   CHECK-DAG:         %[[MV_OUT:.+]] = memref.load %[[ARG4]][%[[IV3]]]
//       CHECK:         %[[MV_MULF:.+]] = arith.mulf %[[MV_LHS]], %[[MV_RHS]]
//       CHECK:         %[[MV_ADDF:.+]] = arith.addf %[[MV_OUT]], %[[MV_MULF]]
//       CHECK:         memref.store %[[MV_ADDF]], %[[ARG4]][%[[IV3]]]
//   CHECK-NOT:   linalg.matvec ins(

// -----

// Lower a linalg.generic whose region reads iteration indices through
// linalg.index. After lowering, `linalg.index 0` / `linalg.index 1` must be
// replaced by the outer / inner loop induction variables. The output map
// (d1, d0) must show up as transposed store indices.
func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
    %arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) {
  linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2 : memref<200x300xi32>, memref<300xi16>, memref<200xi8>)
      outs(%arg3 : memref<300x200xi64>) {
    ^bb0(%b0 : i32, %b1 : i16, %b2 : i8, %b3 : i64):
      %0 = linalg.index 0 : index
      %1 = arith.index_cast %0 : index to i16
      %2 = arith.muli %b1, %1 : i16
      %3 = linalg.index 1 : index
      %4 = arith.index_cast %3 : index to i8
      %5 = arith.muli %b2, %4 : i8
      %6 = arith.extsi %2 : i16 to i32
      %7 = arith.extsi %5 : i8 to i32
      %8 = arith.addi %6, %7 : i32
      %9 = arith.addi %8, %b0 : i32
      %10 = arith.extsi %9 : i32 to i64
      linalg.yield %10 : i64
    }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %0 = transform.structured.convert_to_loops %generic
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
// Argument names are captured with [a-zA-Z0-9]+ rather than a greedy .+ so
// that each capture is limited to its own `%argN` name and cannot span
// several arguments.
// CHECK-LABEL: func @indexed_generic
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<200x300xi32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<300xi16>
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<200xi8>
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: memref<300x200xi64>
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C200:.+]] = arith.constant 200 : index
//   CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C200]] step %[[C1]]
//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C1]]
//   CHECK-DAG:       %[[B0:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV1]]]
//   CHECK-DAG:       %[[B1:.+]] = memref.load %[[ARG1]][%[[IV1]]]
//   CHECK-DAG:       %[[B2:.+]] = memref.load %[[ARG2]][%[[IV0]]]
//       CHECK:       %[[T1:.+]] = arith.index_cast %[[IV0]]
//       CHECK:       %[[T2:.+]] = arith.muli %[[B1]], %[[T1]]
//       CHECK:       %[[T4:.+]] = arith.index_cast %[[IV1]]
//       CHECK:       %[[T5:.+]] = arith.muli %[[B2]], %[[T4]]
//       CHECK:       %[[T6:.+]] = arith.extsi %[[T2]]
//       CHECK:       %[[T7:.+]] = arith.extsi %[[T5]]
//       CHECK:       %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
//       CHECK:       %[[T9:.+]] = arith.addi %[[T8]], %[[B0]]
//       CHECK:       %[[T10:.+]] = arith.extsi %[[T9]]
//       CHECK:       memref.store %[[T10]], %[[ARG3]][%[[IV1]], %[[IV0]]]

// -----

// Lower a 2-D NHWC/HWCF convolution with strides [1, 2] and dilations [3, 4].
// The input window position is computed with affine.apply:
//   row = oh + kh * 3,  col = ow * 2 + kw * 4.
func.func @conv_strides_and_dilation(%input : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>,
  %output : memref<?x?x?x?xf32>) {
  linalg.conv_2d_nhwc_hwcf {
      strides = dense<[1, 2]> : tensor<2xi64>,
      dilations = dense<[3, 4]> : tensor<2xi64>}
      ins(%input, %filter : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
      outs(%output : memref<?x?x?x?xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %conv_op = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %loops = transform.structured.convert_to_loops %conv_op
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

//  CHECK-DAG:  #[[ROW_MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
//  CHECK-DAG:  #[[COL_MAP:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
//       CHECK: func @conv_strides_and_dilation(
//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//  CHECK-SAME:     %[[FILTER:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//  CHECK-SAME:     %[[OUTPUT:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
//   CHECK-DAG:   %[[BATCH:.+]] = memref.dim %[[INPUT]], %[[C0]]
//   CHECK-DAG:   %[[CHANNELS:.+]] = memref.dim %[[INPUT]], %[[C3]]
//   CHECK-DAG:   %[[KH:.+]] = memref.dim %[[FILTER]], %[[C0]]
//   CHECK-DAG:   %[[KW:.+]] = memref.dim %[[FILTER]], %[[C1]]
//   CHECK-DAG:   %[[FILTERS:.+]] = memref.dim %[[FILTER]], %[[C3]]
//   CHECK-DAG:   %[[OH:.+]] = memref.dim %[[OUTPUT]], %[[C1]]
//   CHECK-DAG:   %[[OW:.+]] = memref.dim %[[OUTPUT]], %[[C2]]
//       CHECK:   scf.for %[[N_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[BATCH]] step %[[C1]]
//       CHECK:     scf.for %[[OH_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[OH]] step %[[C1]]
//       CHECK:       scf.for %[[OW_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[OW]] step %[[C1]]
//       CHECK:         scf.for %[[F_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[FILTERS]] step %[[C1]]
//       CHECK:           scf.for %[[KH_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[KH]] step %[[C1]]
//       CHECK:             scf.for %[[KW_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[KW]] step %[[C1]]
//       CHECK:               scf.for %[[C_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[CHANNELS]] step %[[C1]]
//   CHECK-DAG:                 %[[ROW:.+]] = affine.apply #[[ROW_MAP]](%[[OH_IV]], %[[KH_IV]])
//   CHECK-DAG:                 %[[COL:.+]] = affine.apply #[[COL_MAP]](%[[OW_IV]], %[[KW_IV]])
//   CHECK-DAG:                 %[[IN_VAL:.+]] = memref.load %[[INPUT]][%[[N_IV]], %[[ROW]], %[[COL]], %[[C_IV]]]
//   CHECK-DAG:                 %[[FLT_VAL:.+]] = memref.load %[[FILTER]][%[[KH_IV]], %[[KW_IV]], %[[C_IV]], %[[F_IV]]]
//   CHECK-DAG:                 %[[ACC:.+]] = memref.load %[[OUTPUT]][%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[F_IV]]]
//       CHECK:                 %[[PROD:.+]] = arith.mulf %[[IN_VAL]], %[[FLT_VAL]]
//       CHECK:                 %[[SUM:.+]] = arith.addf %[[ACC]], %[[PROD]]
//       CHECK:                 memref.store %[[SUM]], %[[OUTPUT]][%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[F_IV]]]

// -----

// Lower linalg.pooling_nhwc_max with strides [1, 2] and dilations [3, 4].
// The window operand only supplies the (KH, KW) loop extents; its elements
// are never loaded. Each step stores max(accumulator, input element).
func.func @pool_strides_and_dilation(%input : memref<?x?x?x?xf32>, %window : memref<?x?xf32>,
  %output : memref<?x?x?x?xf32>) {
  linalg.pooling_nhwc_max {
      strides = dense<[1, 2]> : tensor<2xi64>,
      dilations = dense<[3, 4]> : tensor<2xi64>}
      ins(%input, %window : memref<?x?x?x?xf32>, memref<?x?xf32>)
      outs(%output : memref<?x?x?x?xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %pool_op = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %loops = transform.structured.convert_to_loops %pool_op
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

//  CHECK-DAG:  #[[ROW_MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
//  CHECK-DAG:  #[[COL_MAP:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
//       CHECK: func @pool_strides_and_dilation
//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//  CHECK-SAME:     %[[WINDOW:[a-zA-Z0-9]+]]: memref<?x?xf32>
//  CHECK-SAME:     %[[OUTPUT:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
//   CHECK-DAG:   %[[BATCH:.+]] = memref.dim %[[INPUT]], %[[C0]]
//   CHECK-DAG:   %[[CHANNELS:.+]] = memref.dim %[[INPUT]], %[[C3]]
//   CHECK-DAG:   %[[KH:.+]] = memref.dim %[[WINDOW]], %[[C0]]
//   CHECK-DAG:   %[[KW:.+]] = memref.dim %[[WINDOW]], %[[C1]]
//   CHECK-DAG:   %[[OH:.+]] = memref.dim %[[OUTPUT]], %[[C1]]
//   CHECK-DAG:   %[[OW:.+]] = memref.dim %[[OUTPUT]], %[[C2]]
//       CHECK:   scf.for %[[N_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[BATCH]] step %[[C1]]
//       CHECK:     scf.for %[[OH_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[OH]] step %[[C1]]
//       CHECK:       scf.for %[[OW_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[OW]] step %[[C1]]
//       CHECK:         scf.for %[[C_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[CHANNELS]] step %[[C1]]
//       CHECK:           scf.for %[[KH_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[KH]] step %[[C1]]
//       CHECK:             scf.for %[[KW_IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[KW]] step %[[C1]]
//   CHECK-DAG:               %[[ROW:.+]] = affine.apply #[[ROW_MAP]](%[[OH_IV]], %[[KH_IV]])
//   CHECK-DAG:               %[[COL:.+]] = affine.apply #[[COL_MAP]](%[[OW_IV]], %[[KW_IV]])
//   CHECK-DAG:               %[[IN_VAL:.+]] = memref.load %[[INPUT]][%[[N_IV]], %[[ROW]], %[[COL]], %[[C_IV]]]
//   CHECK-DAG:               %[[ACC:.+]] = memref.load %[[OUTPUT]][%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[C_IV]]]
//       CHECK:               %[[MAX:.+]] = arith.maximumf %[[ACC]], %[[IN_VAL]]
//       CHECK:               memref.store %[[MAX]], %[[OUTPUT]][%[[N_IV]], %[[OH_IV]], %[[OW_IV]], %[[C_IV]]]

// -----

// Lower an elementwise linalg.map to a single scf.for over the 64 elements.
// Each iteration loads both operands and stores their sum to the output.
func.func @map(%a: memref<64xf32>,
    %b: memref<64xf32>, %dst: memref<64xf32>) {
  linalg.map ins(%a, %b : memref<64xf32>, memref<64xf32>)
             outs(%dst : memref<64xf32>)
    (%x: f32, %y: f32) {
      %sum = arith.addf %x, %y : f32
      linalg.yield %sum : f32
    }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %map_op = transform.structured.match ops{["linalg.map"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %loops = transform.structured.convert_to_loops %map_op
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func.func @map(
// CHECK-SAME:    %[[A:[a-zA-Z0-9]+]]: memref<64xf32>,
// CHECK-SAME:    %[[B:[a-zA-Z0-9]+]]: memref<64xf32>,
// CHECK-SAME:    %[[DST:[a-zA-Z0-9]+]]: memref<64xf32>) {

// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[EXTENT:.*]] = arith.constant 64 : index

// CHECK:     scf.for %[[IV:.*]] = %[[ZERO]] to %[[EXTENT]] step %[[ONE]] {
// CHECK:       %[[A_VAL:.*]] = memref.load %[[A]][%[[IV]]]
// CHECK:       %[[B_VAL:.*]] = memref.load %[[B]][%[[IV]]]
// CHECK:       %[[SUM:.*]] = arith.addf %[[A_VAL]], %[[B_VAL]]
// CHECK:       memref.store %[[SUM]], %[[DST]][%[[IV]]]

// -----

// Lower linalg.transpose with permutation [1, 2, 0]. The loops walk the
// output shape (32, 64, 16). The element read from the input at [d2, d0, d1]
// is written to the output at [d0, d1, d2].
func.func @transpose(%src: memref<16x32x64xf32>,
                     %dst: memref<32x64x16xf32>) {
  linalg.transpose ins(%src : memref<16x32x64xf32>)
                   outs(%dst : memref<32x64x16xf32>) permutation = [1, 2, 0]
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %transpose_op = transform.structured.match ops{["linalg.transpose"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %loops = transform.structured.convert_to_loops %transpose_op
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func.func @transpose(
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
// CHECK-SAME:    %[[DST:[a-zA-Z0-9]+]]: memref<32x64x16xf32>)

// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[SIXTEEN:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[THIRTY_TWO:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[SIXTY_FOUR:.*]] = arith.constant 64 : index

// CHECK:     scf.for %[[D0:.*]] = %[[ZERO]] to %[[THIRTY_TWO]] step %[[ONE]] {
// CHECK:       scf.for %[[D1:.*]] = %[[ZERO]] to %[[SIXTY_FOUR]] step %[[ONE]] {
// CHECK:         scf.for %[[D2:.*]] = %[[ZERO]] to %[[SIXTEEN]] step %[[ONE]] {
// CHECK:           %[[VAL:.*]] = memref.load %[[SRC]][%[[D2]], %[[D0]], %[[D1]]]
// CHECK:           memref.store %[[VAL]], %[[DST]][%[[D0]], %[[D1]], %[[D2]]]

// -----

// Lower a linalg.reduce over dimension 1. All three input dimensions become
// loops. Each iteration accumulates input[d0, d1, d2] into output[d0, d2].
func.func @reduce(%src: memref<16x32x64xf32>,
                  %acc: memref<16x64xf32>) {
  linalg.reduce ins(%src : memref<16x32x64xf32>)
                outs(%acc : memref<16x64xf32>) dimensions = [1]
    (%elem: f32, %partial: f32) {
      %next = arith.addf %elem, %partial : f32
      linalg.yield %next : f32
    }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %reduce_op = transform.structured.match ops{["linalg.reduce"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %loops = transform.structured.convert_to_loops %reduce_op
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func.func @reduce(
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
// CHECK-SAME:    %[[ACC:[a-zA-Z0-9]+]]: memref<16x64xf32>

// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[SIXTEEN:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[THIRTY_TWO:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[SIXTY_FOUR:.*]] = arith.constant 64 : index

// CHECK:     scf.for %[[D0:.*]] = %[[ZERO]] to %[[SIXTEEN]] step %[[ONE]] {
// CHECK:       scf.for %[[D1:.*]] = %[[ZERO]] to %[[THIRTY_TWO]] step %[[ONE]] {
// CHECK:         scf.for %[[D2:.*]] = %[[ZERO]] to %[[SIXTY_FOUR]] step %[[ONE]] {
// CHECK:           %[[SRC_VAL:.*]] = memref.load %[[SRC]][%[[D0]], %[[D1]], %[[D2]]]
// CHECK:           %[[ACC_VAL:.*]] = memref.load %[[ACC]][%[[D0]], %[[D2]]]
// CHECK:           %[[NEXT:.*]] = arith.addf %[[SRC_VAL]], %[[ACC_VAL]]
// CHECK:           memref.store %[[NEXT]], %[[ACC]][%[[D0]], %[[D2]]]

// -----

// Lower a linalg.broadcast that inserts dimension 1. The loops walk the
// output shape (8, 16, 32). The input is read at [d0, d2], ignoring the
// broadcast loop d1, and the value is written to output[d0, d1, d2].
func.func @broadcast(%src: memref<8x32xf32>,
                     %dst: memref<8x16x32xf32>) {
  linalg.broadcast
      ins(%src:memref<8x32xf32>)
      outs(%dst:memref<8x16x32xf32>)
      dimensions = [1]
  func.return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %broadcast_op = transform.structured.match ops{["linalg.broadcast"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %loops = transform.structured.convert_to_loops %broadcast_op
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func.func @broadcast(
// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<8x32xf32>,
// CHECK-SAME:    %[[DST:[a-zA-Z0-9]+]]: memref<8x16x32xf32>

// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[EIGHT:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[SIXTEEN:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[THIRTY_TWO:.*]] = arith.constant 32 : index

// CHECK:     scf.for %[[D0:.*]] = %[[ZERO]] to %[[EIGHT]] step %[[ONE]] {
// CHECK:       scf.for %[[D1:.*]] = %[[ZERO]] to %[[SIXTEEN]] step %[[ONE]] {
// CHECK:         scf.for %[[D2:.*]] = %[[ZERO]] to %[[THIRTY_TWO]] step %[[ONE]] {
// CHECK:           %[[VAL:.*]] = memref.load %[[SRC]][%[[D0]], %[[D2]]]
// CHECK:           memref.store %[[VAL]], %[[DST]][%[[D0]], %[[D1]], %[[D2]]]
