// RUN: mlir-opt %s -transform-interpreter -split-input-file --verify-diagnostics | FileCheck %s

// Check that we produce async copies from the vector.transfer_read/write operations.
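// The transform rewrites each global->workgroup transfer_read/transfer_write
// pair into an nvgpu.device_async_copy, collects both copies into a single
// nvgpu.device_async_create_group, and synchronizes on that group with
// nvgpu.device_async_wait.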
builtin.module {
  // CHECK-LABEL: @copies_to_asyncs
  func.func @copies_to_asyncs(%a: memref<1024x1024xf32>) {
    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %cst_0 = arith.constant 0.000000e+00 : f32
    // Make sure we emit the bypassL1 attribute.
    // CHECK: %[[CP0:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 4  {bypassL1} :
    %1 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<4xf32>
    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
    // CHECK-NOT: nvgpu.device_async_create_group

    // CHECK: %[[CP1:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 1
    %2 = vector.transfer_read %a[%c0, %c4], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<1xf32>
    vector.transfer_write %2, %0[%c0, %c4, %c0] {in_bounds = [true]} : vector<1xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
    // CHECK: %[[G:.*]] = nvgpu.device_async_create_group %[[CP0]], %[[CP1]]
    // CHECK: nvgpu.device_async_wait %[[G]]
    return
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
      %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
      transform.yield
    }
  }
}

// -----

// Check that we properly take `bypass_l1 = false` into account.
// I.e., we shouldn't be generating bypassL1 attributes.
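// Since transform.nvgpu.create_async_groups is invoked without {bypass_l1}
// below, the generated nvgpu.device_async_copy ops should carry no {bypassL1}
// unit attribute and keep the default caching behavior.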
builtin.module {
  // CHECK-LABEL: @copies_to_asyncs_no_mma
  func.func @copies_to_asyncs_no_mma(%a: memref<1024x1024xf32>) {
    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %cst_0 = arith.constant 0.000000e+00 : f32
    // Make sure we don't emit the bypassL1 attribute.
    // CHECK: %[[CP0:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 4 :
    %1 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<4xf32>
    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
    // CHECK-NOT: nvgpu.device_async_create_group

    // CHECK: %[[CP1:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 1 :
    %2 = vector.transfer_read %a[%c0, %c4], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<1xf32>
    vector.transfer_write %2, %0[%c0, %c4, %c0] {in_bounds = [true]} : vector<1xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
    // CHECK: %[[G:.*]] = nvgpu.device_async_create_group %[[CP0]], %[[CP1]]
    // CHECK: nvgpu.device_async_wait %[[G]]
    return
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
      %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.nvgpu.create_async_groups %top_level_func : (!transform.any_op) -> (!transform.any_op)
      transform.yield
    }
  }
}

// -----

// Check that the pattern works with vector.load/vector.store.
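// vector.load/vector.store pairs from global to workgroup memory are expected
// to be rewritten exactly like the transfer_read/transfer_write pairs above.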
builtin.module {
  // CHECK-LABEL: @copies_to_asyncs_load_store
  func.func @copies_to_asyncs_load_store(%a: memref<1024x1024xf32>) {
    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %cst_0 = arith.constant 0.000000e+00 : f32
    // CHECK: %[[CP0:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 4 :
    %1 = vector.load %a[%c0, %c0] : memref<1024x1024xf32>, vector<4xf32>
    vector.store %1, %0[%c0, %c0, %c0] : memref<4x32x16xf32, #gpu.address_space<workgroup>>, vector<4xf32>
    // CHECK-NOT: nvgpu.device_async_create_group

    // CHECK: %[[CP1:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 1 :
    %2 = vector.load %a[%c0, %c4] : memref<1024x1024xf32>, vector<1xf32>
    vector.store %2, %0[%c0, %c4, %c0] : memref<4x32x16xf32, #gpu.address_space<workgroup>>, vector<1xf32>
    // CHECK: %[[G:.*]] = nvgpu.device_async_create_group %[[CP0]], %[[CP1]]
    // CHECK: nvgpu.device_async_wait %[[G]]
    return
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
      %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.nvgpu.create_async_groups %top_level_func : (!transform.any_op) -> (!transform.any_op)
      transform.yield
    }
  }
}

// -----

// Check that the pattern skips unaligned and unsupported sizes.
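// nvgpu.device_async_copy only handles 1-D vectors whose size matches a
// supported cp.async transfer width (4, 8, or 16 bytes), so both copies below
// should be left untouched.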
builtin.module {
  // CHECK-LABEL: @copies_to_asyncs_load_store
  func.func @copies_to_asyncs_load_store(%a: memref<1024x1024xf32>, %b: memref<1024x1024xf16>) {
    %alloc = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
    %alloc_1 = memref.alloc() : memref<4x32x16xf16, #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %cst_0 = arith.constant 0.000000e+00 : f32

    // Requires a 1-D vector load, so this 2-D copy is left as-is.
    // CHECK-NOT: nvgpu.device_async_copy
    //     CHECK: vector.load
    //     CHECK: vector.store
    %1 = vector.load %a[%c0, %c4] : memref<1024x1024xf32>, vector<2x2xf32>
    vector.store %1, %alloc[%c0, %c4, %c0] : memref<4x32x16xf32, #gpu.address_space<workgroup>>, vector<2x2xf32>
    // CHECK-NOT: nvgpu.device_async_create_group

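    // A vector<1xf16> copy moves only 2 bytes, which is below the minimum
    // cp.async transfer size (4 bytes), so it is skipped as well.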
    // CHECK-NOT: nvgpu.device_async_copy
    //     CHECK: vector.load
    //     CHECK: vector.store
    %2 = vector.load %b[%c0, %c4] : memref<1024x1024xf16>, vector<1xf16>
    vector.store %2, %alloc_1[%c0, %c4, %c0] : memref<4x32x16xf16, #gpu.address_space<workgroup>>, vector<1xf16>
    // CHECK-NOT: nvgpu.device_async_create_group
    return
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
      %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.nvgpu.create_async_groups %top_level_func : (!transform.any_op) -> (!transform.any_op)
      transform.yield
    }
  }
}

// -----

// vector.transfer_read with a mask.
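// The trailing mask size (%sz) is expected to become the dynamic srcElements
// operand of nvgpu.device_async_copy; destination elements past it are
// zero-filled (cp.async zfill semantics).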
builtin.module {
  // CHECK-LABEL: @read_with_mask(
  // CHECK-SAME: %{{.*}}: memref<1024x1024xf32>, %[[sz:.*]]: index
  func.func @read_with_mask(%a: memref<1024x1024xf32>, %sz: index) {
    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    %cst_0 = arith.constant 0.000000e+00 : f32
    // CHECK: nvgpu.device_async_copy {{.*}}, {{.*}}, 4, %[[sz]] {bypassL1} :
    %mask = vector.create_mask %sz : vector<4xi1>
    %1 = vector.transfer_read %a[%c0, %c0], %cst_0, %mask {in_bounds = [true]} : memref<1024x1024xf32>, vector<4xf32>
    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>

    return
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
      %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
      transform.yield
    }
  }
}

// -----

// 2D vector.transfer_read with a mask.
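// transfer_to_scf (max_transfer_rank = 1, full_unroll = true) first unrolls the
// 2-D masked read into three 1-D transfers. For each row, the leading mask
// dimension should turn into an arith.cmpi/arith.select pair that picks either
// %sz1 or 0 as the srcElements of the corresponding copy.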
builtin.module {
  // CHECK-LABEL: @read_2d_with_mask(
  //  CHECK-SAME:     %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[a:.*]]: memref<1024x1024xf32>
  func.func @read_2d_with_mask(%sz0: index, %sz1: index, %a: memref<1024x1024xf32>) {
    // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
    // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
    // CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    %cst_0 = arith.constant 0.000000e+00 : f32

    // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
    // CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}

    // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
    // CHECK: %[[s1:.*]] = arith.select %[[cmpi1]], %[[sz1]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}

    // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c2]], %[[sz0]]
    // CHECK: %[[s2:.*]] = arith.select %[[cmpi2]], %[[sz1]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
    %mask = vector.create_mask %sz0, %sz1 : vector<3x4xi1>
    %1 = vector.transfer_read %a[%c0, %c0], %cst_0, %mask {in_bounds = [true, true]} : memref<1024x1024xf32>, vector<3x4xf32>
    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>

    return
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
      %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %top_level_func {
        transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
      } : !transform.any_op
      transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
      %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.apply_cse to %top_level_func_2 : !transform.any_op
      transform.yield
    }
  }
}

// -----

// 3D vector.transfer_read with a mask.
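// Same idea as the 2-D case above: after unrolling to 1-D transfers, the two
// leading mask dimensions combine via arith.andi to guard each copy, and
// apply_cse deduplicates the repeated arith.cmpi computations.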
builtin.module {
  // CHECK-LABEL: @read_3d_with_mask(
  //  CHECK-SAME:     %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index, %[[a:.*]]: memref<1024x1024x1024xf32>
  func.func @read_3d_with_mask(%sz0: index, %sz1: index, %sz2: index, %a: memref<1024x1024x1024xf32>) {
    // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
    // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
    // CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    %cst_0 = arith.constant 0.000000e+00 : f32

    // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
    // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c0]], %[[sz1]]
    // CHECK: %[[cond0:.*]] = arith.andi %[[cmpi1]], %[[cmpi0]]
    // CHECK: %[[s0:.*]] = arith.select %[[cond0]], %[[sz2]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}

    // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c1]], %[[sz1]]
    // CHECK: %[[cond1:.*]] = arith.andi %[[cmpi2]], %[[cmpi0]]
    // CHECK: %[[s1:.*]] = arith.select %[[cond1]], %[[sz2]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}

    // CHECK: %[[cmpi3:.*]] = arith.cmpi slt, %[[c2]], %[[sz1]]
    // CHECK: %[[cond2:.*]] = arith.andi %[[cmpi3]], %[[cmpi0]]
    // CHECK: %[[s2:.*]] = arith.select %[[cond2]], %[[sz2]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}

    // CHECK: %[[cmpi4:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
    // CHECK: %[[cond3:.*]] = arith.andi %[[cmpi1]], %[[cmpi4]]
    // CHECK: %[[s3:.*]] = arith.select %[[cond3]], %[[sz2]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s3]] {bypassL1}

    // CHECK: %[[cond4:.*]] = arith.andi %[[cmpi2]], %[[cmpi4]]
    // CHECK: %[[s4:.*]] = arith.select %[[cond4]], %[[sz2]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s4]] {bypassL1}

    // CHECK: %[[cond5:.*]] = arith.andi %[[cmpi3]], %[[cmpi4]]
    // CHECK: %[[s5:.*]] = arith.select %[[cond5]], %[[sz2]], %[[c0]]
    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s5]] {bypassL1}
    %mask = vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>
    %1 = vector.transfer_read %a[%c0, %c0, %c0], %cst_0, %mask {in_bounds = [true, true, true]} : memref<1024x1024x1024xf32>, vector<2x3x4xf32>
    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>

    return
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
      %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %top_level_func {
        transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
      } : !transform.any_op
      transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
      %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
      transform.apply_cse to %top_level_func_2 : !transform.any_op
      transform.yield
    }
  }
}