// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

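// This file exercises `transform.structured.promote` on linalg ops under
// various options. First case: promote all three matmul operands with
// `use_alloca`, so each tile is copied into a stack buffer created by
// `memref.alloca` and no matching `memref.dealloc` ops are expected.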
#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 4)>
#map3 = affine_map<(d0) -> (d0 + 3)>

func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
  %c4 = arith.constant 4 : index
  %c3 = arith.constant 3 : index
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %3 = memref.view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
  %4 = memref.view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
  %5 = memref.view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
  %6 = memref.dim %3, %c0 : memref<?x?xf32>
  %7 = memref.dim %3, %c1 : memref<?x?xf32>
  %8 = memref.dim %4, %c1 : memref<?x?xf32>
  scf.for %arg4 = %c0 to %6 step %c2 {
    scf.for %arg5 = %c0 to %8 step %c3 {
      scf.for %arg6 = %c0 to %7 step %c4 {
        %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
        %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
        %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
        linalg.matmul
          ins(%11, %14: memref<?x?xf32, strided<[?, 1], offset: ?>>,
                        memref<?x?xf32, strided<[?, 1], offset: ?>>)
         outs(%17: memref<?x?xf32, strided<[?, 1], offset: ?>>)
      }
    }
  }
  return
}

// CHECK-LABEL: func @matmul_f32(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
//       CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:         %[[vA:.*]] = memref.subview {{.*}} : memref<?x?xf32>
//       CHECK:         %[[vB:.*]] = memref.subview {{.*}} : memref<?x?xf32>
//       CHECK:         %[[vC:.*]] = memref.subview {{.*}} : memref<?x?xf32>
///
//       CHECK:         %[[tmpA:.*]] = memref.alloca() : memref<32xi8>
//       CHECK:         %[[fullA:.*]] = memref.view %[[tmpA]][{{.*}}][{{.*}}] : memref<32xi8> to memref<?x?xf32>
//       CHECK:         %[[partialA:.*]] = memref.subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
///
//       CHECK:         %[[tmpB:.*]] = memref.alloca() : memref<48xi8>
//       CHECK:         %[[fullB:.*]] = memref.view %[[tmpB]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf32>
//       CHECK:         %[[partialB:.*]] = memref.subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
///
//       CHECK:         %[[tmpC:.*]] = memref.alloca() : memref<24xi8>
//       CHECK:         %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xf32>
//       CHECK:         %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>

//       CHECK:         linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1]>>)
//       CHECK:         linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1]>>)
//       CHECK:         linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1]>>)
//
//       CHECK:         linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
//
//       CHECK:         linalg.copy ins(%[[partialC]] : memref<?x?xf32, strided<[?, 1]>>) outs(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
//
//   CHECK-NOT:         memref.dealloc %[[tmpA]] : memref<32xi8>
//   CHECK-NOT:         memref.dealloc %[[tmpB]] : memref<48xi8>
//   CHECK-NOT:         memref.dealloc %[[tmpC]] : memref<24xi8>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.promote %0 { use_alloca } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

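// Same promotion for f64 without `use_alloca`: the tile buffers are created
// with `memref.alloc` (twice the byte size of the f32 buffers above) and are
// explicitly freed with `memref.dealloc`.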
func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
  %c4 = arith.constant 4 : index
  %c3 = arith.constant 3 : index
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %3 = memref.view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf64>
  %4 = memref.view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf64>
  %5 = memref.view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf64>
  %6 = memref.dim %3, %c0 : memref<?x?xf64>
  %7 = memref.dim %3, %c1 : memref<?x?xf64>
  %8 = memref.dim %4, %c1 : memref<?x?xf64>
  scf.for %arg4 = %c0 to %6 step %c2 {
    scf.for %arg5 = %c0 to %8 step %c3 {
      scf.for %arg6 = %c0 to %7 step %c4 {
        %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
        %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
        %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
        linalg.matmul
          ins(%11, %14: memref<?x?xf64, strided<[?, 1], offset: ?>>,
                        memref<?x?xf64, strided<[?, 1], offset: ?>>)
         outs(%17: memref<?x?xf64, strided<[?, 1], offset: ?>>)
      }
    }
  }
  return
}

// CHECK-LABEL: func @matmul_f64(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
//       CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:         %[[vA_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
//       CHECK:         %[[vB_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
//       CHECK:         %[[vC_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
///
//       CHECK:         %[[tmpA_f64:.*]] = memref.alloc() : memref<64xi8>
//       CHECK:         %[[fullA_f64:.*]] = memref.view %[[tmpA_f64]][{{.*}}][{{.*}}] : memref<64xi8> to memref<?x?xf64>
//       CHECK:         %[[partialA_f64:.*]] = memref.subview %[[fullA_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1]>>
///
//       CHECK:         %[[tmpB_f64:.*]] = memref.alloc() : memref<96xi8>
//       CHECK:         %[[fullB_f64:.*]] = memref.view %[[tmpB_f64]][{{.*}}][{{.*}}] : memref<96xi8> to memref<?x?xf64>
//       CHECK:         %[[partialB_f64:.*]] = memref.subview %[[fullB_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1]>>
///
//       CHECK:         %[[tmpC_f64:.*]] = memref.alloc() : memref<48xi8>
//       CHECK:         %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf64>
//       CHECK:         %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1]>>

//       CHECK:         linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1]>>)
//       CHECK:         linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1]>>)
//       CHECK:         linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1]>>)
//
//       CHECK:         linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
//
//       CHECK:         linalg.copy ins(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1]>>) outs(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
//
//       CHECK:         memref.dealloc %[[tmpA_f64]] : memref<64xi8>
//       CHECK:         memref.dealloc %[[tmpB_f64]] : memref<96xi8>
//       CHECK:         memref.dealloc %[[tmpC_f64]] : memref<48xi8>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.promote %0 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----
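
// Test promotion of only the two input operands (`operands_to_promote = [0, 1]`)
// of a tiled matmul into GPU workgroup (shared) memory; the copies into the
// promoted buffers are bracketed by `gpu.barrier` ops.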
func.func @gemm_shared(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
   linalg.matmul ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
               outs(%c: memref<?x?xf32>)
   return
}

// CHECK: func @gemm_shared
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK: %[[alloc_A:.*]] = memref.alloc() : memref<16x16xf32, #gpu.address_space<workgroup>>
// CHECK: %[[alloc_B:.*]] = memref.alloc() : memref<16x16xf32, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[C16:.*]] = arith.constant 16
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
// CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK:         %[[subview_A:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK:         %[[subview_B:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK:         %[[subview_C:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>

// CHECK:         %[[shared_A:.*]] = memref.subview %[[alloc_B]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
// CHECK:         %[[shared_B:.*]] = memref.subview %[[alloc_A]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>

// CHECK-NEXT:    gpu.barrier
// CHECK-NEXT:    memref.copy %[[subview_A]], %[[shared_A]] :  memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
// CHECK-NEXT:    gpu.barrier

// CHECK-NEXT:    gpu.barrier
// CHECK-NEXT:    memref.copy %[[subview_B]], %[[shared_B]] :  memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
// CHECK-NEXT:    gpu.barrier

// CHECK:         linalg.matmul ins(%[[shared_A]], %[[shared_B]]{{.*}} outs(%[[subview_C]]


module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [16, 16, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    %2 = transform.structured.promote %1 { operands_to_promote = [0, 1], mapping = [#gpu.memory_space<workgroup>] } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}


// -----

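// Same as above, but promoting into GPU private memory: `memref.alloca` is
// used instead of `memref.alloc`, and no `gpu.barrier` ops are inserted
// around the copies.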
func.func @gemm_private(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
   linalg.matmul ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
               outs(%c: memref<?x?xf32>)
   return
}

// CHECK: func @gemm_private
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK: %[[alloc_A:.*]] = memref.alloca() : memref<16x16xf32, #gpu.address_space<private>>
// CHECK: %[[alloc_B:.*]] = memref.alloca() : memref<16x16xf32, #gpu.address_space<private>>
// CHECK-DAG: %[[C16:.*]] = arith.constant 16
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
// CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK:         %[[subview_A:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK:         %[[subview_B:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK:         %[[subview_C:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>

// CHECK:         %[[private_A:.*]] = memref.subview %[[alloc_B]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<private>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
// CHECK:         %[[private_B:.*]] = memref.subview %[[alloc_A]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<private>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>

// CHECK-NEXT:    memref.copy %[[subview_A]], %[[private_A]] :  memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
// CHECK-NEXT:    memref.copy %[[subview_B]], %[[private_B]] :  memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>

// CHECK:         linalg.matmul ins(%[[private_A]], %[[private_B]]{{.*}} outs(%[[subview_C]]


module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [16, 16, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    %2 = transform.structured.promote %1 { operands_to_promote = [0, 1], mapping = [#gpu.memory_space<private>] } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}


// -----

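// Test that promotion handles rank-reducing subviews: the promoted buffers,
// views, and copies are created for the rank-reduced operand types consumed
// by the linalg.generic.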
#map6 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map7 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map8 = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK: promote_rank_reducing_subviews(%[[arg0:.+]]: memref<{{.*}}>, %[[arg1:.+]]: memref<{{.*}}>, %[[arg2:.+]]: memref<{{.*}}>, %[[lb1:.+]]: index, %[[lb2:.+]]: index, %[[lb3:.+]]: index, %[[lb4:.+]]: index, %[[lb5:.+]]: index, %[[lb6:.+]]: index, %[[ub1:.+]]: index, %[[ub2:.+]]: index
func.func @promote_rank_reducing_subviews(%arg0:  memref<?x?x?x64xf32, strided<[?, ?, ?, ?], offset: ?>>, %arg1: memref<128x3x3x64xf32, strided<[?, ?, ?, ?], offset: ?>>, %arg2: memref<?x?x?x128xf32>,
                                          %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %ub1: index, %ub2: index) {
  %13 = memref.subview %arg0[%arg3, 0, %arg4, %arg8] [1, 1, %ub1, 32] [1, 1, 1, 1] : memref<?x?x?x64xf32, strided<[?, ?, ?, ?], offset: ?>> to memref<?x32xf32, strided<[?, ?], offset: ?>>
  %14 = memref.subview %arg1[0, %arg6, %arg7, %arg8] [128, 1, 1, 32] [1, 1, 1, 1] : memref<128x3x3x64xf32, strided<[?, ?, ?, ?], offset: ?>> to memref<128x32xf32, strided<[?, ?], offset: ?>>
  %9 = memref.subview %arg2[%arg3, %arg4, %arg5, 0] [1, 1, %ub2, 128] [1, 1, 1, 1] : memref<?x?x?x128xf32> to memref<?x128xf32, strided<[128, 1], offset: ?>>

  // CHECK: %[[a_alloc:.+]] = memref.alloc
  // CHECK: %[[a_view:.+]] = memref.view %[[a_alloc]]{{.*}}
  // CHECK: %[[a_pro_subview:.+]] = memref.subview %[[a_view]][0, 0] [%[[ub1]], {{.+}}] [1, 1]

  // CHECK: memref.alloc
  // CHECK: %[[b_view:.+]] = memref.view
  // CHECK: %[[b_pro_subview:.+]] = memref.subview %[[b_view]]

  // CHECK: memref.alloc
  // CHECK: %[[c_view:.+]] = memref.view
  // CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]]

  // CHECK-COUNT-3: linalg.copy
  // CHECK: linalg.generic
  // CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]]
  // CHECK-SAME: outs(%[[c_pro_subview]]

  linalg.generic {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : memref<?x32xf32, strided<[?, ?], offset: ?>>, memref<128x32xf32, strided<[?, ?], offset: ?>>) outs(%9 : memref<?x128xf32, strided<[128, 1], offset: ?>>) {
  ^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
    %15 = arith.mulf %arg9, %arg10 : f32
    %16 = arith.addf %arg11, %15 : f32
    linalg.yield %16 : f32
  }

  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.promote %0 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

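// Test promotion of all inputs and outputs of a linalg.generic into the
// `#gpu.address_space<workgroup>` memory space selected via the
// `memory_space` attribute, with copies in before and out after the
// computation and explicit deallocation of the promoted buffers.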
#map = affine_map<(d0, d1) -> (d0, d1)>

  // CHECK-LABEL:   func.func @linalg_generic_update_all_function_inputs_outputs(
  // CHECK-SAME:                                                                 %[[VAL_0:.*]]: memref<3x4xf32, 1>,
  // CHECK-SAME:                                                                 %[[VAL_1:.*]]: memref<3x4xf32, 1>) -> memref<3x4xf32, 1> {
func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf32, 1>, %arg1: memref<3x4xf32, 1>) -> memref<3x4xf32, 1> {
  // CHECK:           %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x4xf32, 1>
  // CHECK:           %[[VAL_3:.*]] = memref.subview %[[VAL_0]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
  // CHECK:           %[[VAL_4:.*]] = memref.subview %[[VAL_1]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
  // CHECK:           %[[VAL_5:.*]] = memref.subview %[[VAL_2]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>

  %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x4xf32, 1>
  %subview = memref.subview %arg0[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
  %subview_0 = memref.subview %arg1[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
  %subview_1 = memref.subview %alloc[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>

  // CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_7:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_10:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_11:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_12:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_13:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_14:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_15:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_17:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_18:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_19:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_20:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_21:.*]] = arith.constant 12 : index
  // CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
  // CHECK:           %[[VAL_23:.*]] = memref.view %[[VAL_22]]{{\[}}%[[VAL_18]]]{{\[}}%[[VAL_12]], %[[VAL_15]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
  // CHECK:           %[[VAL_24:.*]] = memref.subview %[[VAL_23]][0, 0] {{\[}}%[[VAL_14]], %[[VAL_17]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>
  // CHECK:           %[[VAL_25:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_26:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_27:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_28:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_29:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_30:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_31:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_32:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_33:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_34:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_35:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_36:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_37:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_38:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_39:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_40:.*]] = arith.constant 12 : index
  // CHECK:           %[[VAL_41:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
  // CHECK:           %[[VAL_42:.*]] = memref.view %[[VAL_41]]{{\[}}%[[VAL_37]]]{{\[}}%[[VAL_31]], %[[VAL_34]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
  // CHECK:           %[[VAL_43:.*]] = memref.subview %[[VAL_42]][0, 0] {{\[}}%[[VAL_33]], %[[VAL_36]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>
  // CHECK:           %[[VAL_44:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_45:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_46:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_47:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_48:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_49:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_50:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_51:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_52:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_53:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_54:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_55:.*]] = arith.constant 3 : index
  // CHECK:           %[[VAL_56:.*]] = arith.constant 0 : index
  // CHECK:           %[[VAL_57:.*]] = arith.constant 1 : index
  // CHECK:           %[[VAL_58:.*]] = arith.constant 4 : index
  // CHECK:           %[[VAL_59:.*]] = arith.constant 12 : index
  // CHECK:           %[[VAL_60:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
  // CHECK:           %[[VAL_61:.*]] = memref.view %[[VAL_60]]{{\[}}%[[VAL_56]]]{{\[}}%[[VAL_50]], %[[VAL_53]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
  // CHECK:           %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>
  // CHECK:           linalg.copy ins(%[[VAL_3]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_24]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>)
  // CHECK:           linalg.copy ins(%[[VAL_4]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_43]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>)
  // CHECK:           linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>, memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>) outs(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>) {
  // CHECK:           ^bb0(%[[VAL_63:.*]]: f32, %[[VAL_64:.*]]: f32, %[[VAL_65:.*]]: f32):
  // CHECK:             %[[VAL_66:.*]] = arith.addf %[[VAL_63]], %[[VAL_64]] : f32
  // CHECK:             linalg.yield %[[VAL_66]] : f32
  // CHECK:           }


  linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%subview, %subview_0 : memref<4x3xf32, strided<[4, 1]>, 1>, memref<4x3xf32, strided<[4, 1]>, 1>) outs(%subview_1 : memref<4x3xf32, strided<[4, 1]>, 1>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %1 = arith.addf %in, %in_1 : f32
    linalg.yield %1 : f32
  }

  // CHECK:           linalg.copy ins(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>) outs(%[[VAL_5]] : memref<4x3xf32, strided<[4, 1]>, 1>)
  // CHECK:           memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space<workgroup>>
  // CHECK:           memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space<workgroup>>
  // CHECK:           memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space<workgroup>>
  // CHECK:           return %[[VAL_2]] : memref<3x4xf32, 1>
  // CHECK:         }

  return %alloc : memref<3x4xf32, 1>
}


module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.promote %0 { memory_space = #gpu.address_space<workgroup> } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}