xref: /llvm-project/mlir/test/Dialect/Mesh/ops.mlir (revision 79eb406a67fe08458548289da72cda18248a9313)
1// RUN: mlir-opt %s | mlir-opt | FileCheck %s
2
// Mesh symbols referenced by the test functions in this file. Every
// declaration is matched together with its shape so that a printer change
// affecting any of them is caught.
3// CHECK: mesh.mesh @mesh0(shape = 2x2x4)
4mesh.mesh @mesh0(shape = 2x2x4)
5
6// CHECK: mesh.mesh @mesh1(shape = 4x?)
7mesh.mesh @mesh1(shape = 4x?)
8
9// CHECK: mesh.mesh @mesh2(shape = ?x4)
10mesh.mesh @mesh2(shape = ?x4)
11
12// CHECK: mesh.mesh @mesh3(shape = ?x?)
13mesh.mesh @mesh3(shape = ?x?)
14
// CHECK: mesh.mesh @mesh4(shape = 3)
15mesh.mesh @mesh4(shape = 3)
16
17// CHECK: mesh.mesh @mesh5(shape = ?)
18mesh.mesh @mesh5(shape = ?)
19
// Round-trip of mesh.sharding / mesh.shard on @mesh0 with an empty split_axes
// entry ([[]]), i.e. no tensor dimension is split.
20// CHECK-LABEL: func @mesh_shard_op_fully_replicated
21// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
22func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
23  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
24  %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
25  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
26  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
27  return %0 : tensor<4x8xf32>
28}
29
// Round-trip of mesh.sharding / mesh.shard on @mesh0 with split_axes = [[0]].
// The mesh.shard op is matched too, mirroring the fully_replicated and
// 2nd_dim cases in this file.
30// CHECK-LABEL: func @mesh_shard_op_1st_dim
31// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
32func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
33  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
34  %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
35  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
36  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
37  return %0 : tensor<4x8xf32>
38}
39
// Round-trip of mesh.sharding / mesh.shard on @mesh1 with
// split_axes = [[], [0]]: the first entry is empty, the second names mesh
// axis 0.
40// CHECK-LABEL: func @mesh_shard_op_2nd_dim
41// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
42func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
43  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
44  %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
45  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
46  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
47  return %0 : tensor<4x8xf32>
48}
49
// Round-trip of mesh.sharding / mesh.shard on @mesh3 for a rank-3 tensor with
// split_axes = [[0], [], [1]] (middle entry empty).
50// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
51func.func @mesh_shard_op_1st_and_3rd_dim(
52    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
53    %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
54  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding
55  %s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding
56  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
57  %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32>
58  return %0 : tensor<4x8x16xf32>
59}
60
// Round-trip of mesh.sharding with a `partial = max` clause. The input spells
// it `max[1]`; the expected output has a space before the bracket (`max [1]`).
61// CHECK-LABEL: func @mesh_shard_op_partial_max
62func.func @mesh_shard_op_partial_max(
63    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
64    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
65  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = max [1] : !mesh.sharding
66  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = max[1] : !mesh.sharding
67  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
68  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
69  return %0 : tensor<4x8xf32>
70}
71
// Round-trip of mesh.sharding with a `partial = min` clause on mesh axis 1 of
// @mesh3.
72// CHECK-LABEL: func @mesh_shard_op_partial_min
73func.func @mesh_shard_op_partial_min(
74    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
75    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
76  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = min [1] : !mesh.sharding
77  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = min[1] : !mesh.sharding
78  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
79  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
80  return %0 : tensor<4x8xf32>
81}
82
// Round-trip of mesh.sharding with a `partial = generic` clause on mesh axis 1
// of @mesh3.
83// CHECK-LABEL: func @mesh_shard_op_partial_generic
84func.func @mesh_shard_op_partial_generic(
85    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
86    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
87  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = generic [1] : !mesh.sharding
88  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = generic[1] : !mesh.sharding
89  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
90  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
91  return %0 : tensor<4x8xf32>
92}
93
// Round-trip of mesh.sharding with an explicit `partial = sum` clause on mesh
// axis 1 of @mesh3; the clause is expected to be printed back.
94// CHECK-LABEL: func @mesh_shard_op_partial_sum
95func.func @mesh_shard_op_partial_sum(
96    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
97    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
98  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1] : !mesh.sharding
99  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1] : !mesh.sharding
100  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
101  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
102  return %0 : tensor<4x8xf32>
103}
104
// Round-trip of mesh.sharding with a `partial = sum` clause listing several
// mesh axes ([1, 2]).
// NOTE(review): @mesh3 is declared as `?x?` (two axes), so partial axis 2
// looks out of range -- confirm this is intentional for a syntax-only test.
105// CHECK-LABEL: func @mesh_shard_op_partial_sum_multi_axes
106func.func @mesh_shard_op_partial_sum_multi_axes(
107    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
108    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
109  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1, 2] : !mesh.sharding
110  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1, 2] : !mesh.sharding
111  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
112  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
113  return %0 : tensor<4x8xf32>
114}
115
// One sharded value (%0) is re-annotated twice with `annotate_for_users`,
// each time with a different sharding, and both results are returned.
// NOTE(review): only the three mesh.sharding ops are matched here; V0 is
// captured but never referenced and the mesh.shard ops are not matched.
116// CHECK-LABEL: func @mesh_shard_op_two_users
117// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
118func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
119                                  (tensor<4x8xf32>, tensor<4x8xf32>) {
120  // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
121  %s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
122  %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32>
123  // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}1]] : !mesh.sharding
124  %s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
125  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
126  // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding
127  %s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
128  %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
129  return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
130}
131
// Round-trip of mesh.sharding with a halo_sizes clause, once with only
// literal sizes and once mixing a literal with an i64 SSA value.
132// CHECK-LABEL: func @mesh_shard_halo_sizes
133func.func @mesh_shard_halo_sizes() -> () {
134  // CHECK: %[[C3:.*]] = arith.constant 3 : i64
135  %c3 = arith.constant 3 : i64
136  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding
137  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding
138  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding
139  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding
140  return
141}
142
// Round-trip of mesh.sharding with a sharded_dims_offsets clause, once with
// only literal offsets and once with an i64 SSA value among the literals.
143// CHECK-LABEL: func @mesh_shard_dims_sizes
144func.func @mesh_shard_dims_sizes() -> () {
145  // CHECK: %[[C3:.*]] = arith.constant 3 : i64
146  %c3 = arith.constant 3 : i64
147  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
148  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
149  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding
150  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding
151  return
152}
153
// Round-trip of mesh.shard_shape for a shape with a dynamic dimension (8x?)
// and for a fully static shape (8x4). Both uses take the sharding value and
// an index operand and produce two index results.
154// CHECK-LABEL: func @mesh_shard_shape
155func.func @mesh_shard_shape() {
156  // CHECK: %[[C3:.*]] = arith.constant 3 : index
157  %c3 = arith.constant 3 : index
158  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
159  %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
160  // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
161  %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
162  // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
163  %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
164  return
165}
166
// Round-trip of mesh.mesh_shape with an explicit axes list ([0, 1]); one
// index result per listed axis.
167// CHECK-LABEL: func @mesh_shape
168func.func @mesh_shape() -> (index, index) {
169  // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
170  %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
171  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
172  return %0#0, %0#1 : index, index
173}
174
// Round-trip of mesh.mesh_shape without an axes clause; three index results
// are produced for @mesh0.
175// CHECK-LABEL: func @mesh_shape_default_axes
176func.func @mesh_shape_default_axes() -> (index, index, index) {
177  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
178  %0:3 = mesh.mesh_shape @mesh0 : index, index, index
179  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
180  return %0#0, %0#1, %0#2 : index, index, index
181}
182
// mesh.mesh_shape written with an explicit empty axes list (`axes = []`); the
// expected output omits the axes clause entirely.
183// CHECK-LABEL: func @mesh_shape_empty_axes
184func.func @mesh_shape_empty_axes() -> (index, index, index) {
185  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
186  %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
187  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
188  return %0#0, %0#1, %0#2 : index, index, index
189}
190
// Round-trip of mesh.process_multi_index with an explicit axes list ([0, 1]);
// one index result per listed axis.
191// CHECK-LABEL: func @process_multi_index
192func.func @process_multi_index() -> (index, index) {
193  // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
194  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
195  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
196  return %0#0, %0#1 : index, index
197}
198
// Round-trip of mesh.process_multi_index without an axes clause; three index
// results are produced for @mesh0.
199// CHECK-LABEL: func @process_multi_index_default_axes
200func.func @process_multi_index_default_axes() -> (index, index, index) {
201  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
202  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
203  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
204  return %0#0, %0#1, %0#2 : index, index, index
205}
206
// mesh.process_multi_index written with an explicit empty axes list
// (`axes = []`); the expected output omits the axes clause entirely.
207// CHECK-LABEL: func @process_multi_index_empty_axes
208func.func @process_multi_index_empty_axes() -> (index, index, index) {
209  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
210  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
211  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
212  return %0#0, %0#1, %0#2 : index, index, index
213}
214
// Round-trip of mesh.process_linear_index on @mesh0, which yields a single
// index result.
215// CHECK-LABEL: func @process_linear_index
216func.func @process_linear_index() -> index {
217  // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
218  %0 = mesh.process_linear_index on @mesh0 : index
219  // CHECK: return %[[RES]] : index
220  return %0 : index
221}
222
// Round-trip of mesh.all_reduce with mesh_axes = [1, 0], reduction = max, and
// a result element type (f64) that differs from the operand's (f32).
223// CHECK-LABEL: func @all_reduce
224func.func @all_reduce(
225    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
226    %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
227  // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max
228  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
229  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max
230    : tensor<3x4xf32> -> tensor<3x4xf64>
231  return %0 : tensor<3x4xf64>
232}
233
// Round-trip of mesh.all_gather with static shapes: tensor axis 1 is gathered
// over mesh axis 2 of @mesh0 (3x4 -> 3x16).
234// CHECK-LABEL: func @all_gather
235func.func @all_gather(
236    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
237    %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
238  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
239  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
240  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
241    : tensor<3x4xf32> -> tensor<3x16xf32>
242  return %0 : tensor<3x16xf32>
243}
244
// Round-trip of mesh.all_gather where both the operand and the result tensor
// have fully dynamic shapes.
245// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
246func.func @all_gather_dynamic_dims_in_tensor(
247    // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
248    %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
249  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
250  // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
251  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
252    : tensor<?x?xf32> -> tensor<?x?xf32>
253  return %0 : tensor<?x?xf32>
254}
255
// Round-trip of mesh.all_gather over axis 1 of @mesh3, whose shape is declared
// as `?x?`; the operand is static and the gathered result dimension is
// dynamic.
256// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
257func.func @all_gather_dynamic_dims_in_mesh(
258    // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
259    %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
260  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
261  // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
262  %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
263    : tensor<5x6xf32> -> tensor<5x?xf32>
264  return %0 : tensor<5x?xf32>
265}
266
// Round-trip of mesh.all_slice with static shapes: tensor axis 1 is sliced
// over mesh axis 2 of @mesh0 (3x4 -> 3x1).
267// CHECK-LABEL: func @all_slice_static_dimensions
268func.func @all_slice_static_dimensions(
269    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
270    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
271  // CHECK-NEXT: mesh.all_slice %[[ARG]]
272  // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
273  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
274  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1
275    : tensor<3x4xf32> -> tensor<3x1xf32>
276  return %0 : tensor<3x1xf32>
277}
278
// Round-trip of mesh.all_slice with dynamic operand and result shapes, sliced
// over mesh axes [0, 1] of @mesh3.
279// CHECK-LABEL: func @all_slice_dynamic_dimensions
280func.func @all_slice_dynamic_dimensions(
281    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
282    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
283  // CHECK-NEXT: mesh.all_slice %[[ARG]]
284  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
285  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
286  %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0
287    : tensor<?xf32> -> tensor<?xf32>
288  return %0 : tensor<?xf32>
289}
290
// Round-trip of mesh.all_to_all on @mesh4 without a mesh_axes clause, with
// split_axis = 1 and concat_axis = 0 and static operand/result shapes.
291// CHECK-LABEL: func @all_to_all
292func.func @all_to_all(
293    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
294    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
295  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
296  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
297  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
298  %0 = mesh.all_to_all %arg0 on @mesh4
299    split_axis = 1 concat_axis = 0
300    : tensor<3x6xi8> -> tensor<3x6xi8>
301  return %0 : tensor<3x6xi8>
302}
303
// Same mesh.all_to_all form as the static case on @mesh4, but the result type
// has a dynamic dimension (3x?) while the operand is static (3x6).
304// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
305func.func @all_to_all_dynamic_dims_in_result(
306    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
307    %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
308  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
309  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
310  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
311  %0 = mesh.all_to_all %arg0 on @mesh4
312    split_axis = 1 concat_axis = 0
313    : tensor<3x6xi8> -> tensor<3x?xi8>
314  return %0 : tensor<3x?xi8>
315}
316
// Round-trip of mesh.all_to_all where split_axis and concat_axis are the same
// tensor axis (0), so operand and result types are identical.
// NOTE(review): the name mentions a dynamic device group size, but @mesh4 is
// declared with a static shape (3) -- confirm which mesh was intended.
317// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
318func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
319    // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
320    %arg0 : tensor<3xi8>) -> tensor<3xi8> {
321  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
322  // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
323  // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
324  %0 = mesh.all_to_all %arg0 on @mesh4
325    split_axis = 0 concat_axis = 0
326    : tensor<3xi8> -> tensor<3xi8>
327  return %0 : tensor<3xi8>
328}
329
// Round-trip of mesh.all_to_all over mesh axes [0, 1] of @mesh0 with a 2x3
// operand split along tensor axis 0; the result's split dimension is dynamic
// and the concatenated dimension is static (?x12).
330// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
331func.func @all_to_all_non_divisible_split_axis_size(
332    // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
333    %arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
334  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
335  // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
336  // CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
337  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
338    split_axis = 0 concat_axis = 1
339    : tensor<2x3xi8> -> tensor<?x12xi8>
340  return %0 : tensor<?x12xi8>
341}
342
// Round-trip of mesh.broadcast over mesh axes [0, 2] of @mesh0 with a root
// given entirely as literals ([0, 1]).
343// CHECK-LABEL: func @broadcast_static_root
344func.func @broadcast_static_root(
345    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
346    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
347  // CHECK-NEXT: mesh.broadcast %[[ARG]]
348  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
349  // CHECK-SAME: root = [0, 1]
350  // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
351  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
352    root = [0, 1]
353    : (tensor<3x6xi8>) -> tensor<3x6xi8>
354  return %0 : tensor<3x6xi8>
355}
356
// Round-trip of mesh.broadcast whose root mixes a literal with an index SSA
// value; the dynamic component shows up as an extra `index` operand type.
357// CHECK-LABEL: func @broadcast_dynamic_root
358func.func @broadcast_dynamic_root(
359    // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
360    %arg0 : tensor<3x6xi8>,
361    // CHECK-SAME: %[[ARG1:.*]]: index
362    %arg1 : index
363    ) -> tensor<3x6xi8> {
364  // CHECK-NEXT: mesh.broadcast %[[ARG0]]
365  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
366  // CHECK-SAME: root = [1, %[[ARG1]]]
367  // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
368  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
369    root = [1, %arg1]
370    : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
371  return %0 : tensor<3x6xi8>
372}
373
// Round-trip of mesh.gather along tensor axis 0 over mesh axes [0, 2] of
// @mesh0 with a literal root ([0, 1]); 3x6 operand, 24x6 result.
374// CHECK-LABEL: func @gather_static_root
375func.func @gather_static_root(
376    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
377    %arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> {
378  // CHECK-NEXT: mesh.gather %[[ARG]]
379  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
380  // CHECK-SAME: gather_axis = 0
381  // CHECK-SAME: root = [0, 1]
382  // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
383  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
384    gather_axis = 0
385    root = [0, 1]
386    : (tensor<3x6xi8>) -> tensor<24x6xi8>
387  return %0 : tensor<24x6xi8>
388}
389
// Round-trip of mesh.gather whose root mixes a literal with an index SSA
// value; the dynamic component shows up as an extra `index` operand type.
390// CHECK-LABEL: func @gather_dynamic_root
391func.func @gather_dynamic_root(
392    // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
393    %arg0 : tensor<3x6xi8>,
394    // CHECK-SAME: %[[ARG1:.*]]: index
395    %arg1 : index
396    ) -> tensor<24x6xi8> {
397  // CHECK-NEXT: mesh.gather %[[ARG0]]
398  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
399  // CHECK-SAME: gather_axis = 0
400  // CHECK-SAME: root = [1, %[[ARG1]]]
401  // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
402  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
403    gather_axis = 0
404    root = [1, %arg1]
405    : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
406  return %0 : tensor<24x6xi8>
407}
408
// Round-trip of mesh.recv over mesh axes [0, 2] of @mesh0 with a source given
// entirely as literals ([0, 1]).
409// CHECK-LABEL: func @receive_static_source
410func.func @receive_static_source(
411    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
412    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
413  // CHECK-NEXT: mesh.recv %[[ARG]]
414  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
415  // CHECK-SAME: source = [0, 1]
416  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
417  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
418    source = [0, 1]
419    : (tensor<2xi8>) -> tensor<2xi8>
420  return %0 : tensor<2xi8>
421}
422
// Round-trip of mesh.recv whose source mixes a literal with an index SSA
// value; the dynamic component shows up as an extra `index` operand type.
423// CHECK-LABEL: func @receive_dynamic_source
424func.func @receive_dynamic_source(
425    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
426    %arg0 : tensor<2xi8>,
427    // CHECK-SAME: %[[ARG1:.*]]: index
428    %arg1 : index
429    ) -> tensor<2xi8> {
430  // CHECK-NEXT: mesh.recv %[[ARG0]]
431  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
432  // CHECK-SAME: source = [1, %[[ARG1]]]
433  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
434  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
435    source = [1, %arg1]
436    : (tensor<2xi8>, index) -> tensor<2xi8>
437  return %0 : tensor<2xi8>
438}
439
// mesh.recv written without a source clause; the negative match below asserts
// that no `source` text appears in the remainder of this function's output.
440// CHECK-LABEL: func @receive_no_source
441func.func @receive_no_source(
442    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
443    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
444  // CHECK-NEXT: mesh.recv %[[ARG]]
445  // CHECK-NOT: source
446  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
447    : (tensor<2xi8>) -> tensor<2xi8>
448  return %0 : tensor<2xi8>
449}
450
// Round-trip of mesh.reduce over mesh axes [0, 2] of @mesh0 with a literal
// root ([0, 1]) and no reduction clause.
451// CHECK-LABEL: func @reduce_static_root
452func.func @reduce_static_root(
453    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
454    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
455  // CHECK-NEXT: mesh.reduce %[[ARG]]
456  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
457  // CHECK-SAME: root = [0, 1]
458  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
459  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
460    root = [0, 1]
461    : (tensor<2xi8>) -> tensor<2xi8>
462  return %0 : tensor<2xi8>
463}
464
// Round-trip of mesh.reduce whose root mixes a literal with an index SSA
// value; the dynamic component shows up as an extra `index` operand type.
465// CHECK-LABEL: func @reduce_dynamic_root
466func.func @reduce_dynamic_root(
467    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
468    %arg0 : tensor<2xi8>,
469    // CHECK-SAME: %[[ARG1:.*]]: index
470    %arg1 : index
471    ) -> tensor<2xi8> {
472  // CHECK-NEXT: mesh.reduce %[[ARG0]]
473  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
474  // CHECK-SAME: root = [1, %[[ARG1]]]
475  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
476  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
477    root = [1, %arg1]
478    : (tensor<2xi8>, index) -> tensor<2xi8>
479  return %0 : tensor<2xi8>
480}
481
// Round-trip of mesh.reduce where the result element type (i16) differs from
// the operand element type (i8).
482// CHECK-LABEL: func @reduce_different_return_element_type
483func.func @reduce_different_return_element_type(
484    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
485    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
486  // CHECK-NEXT: mesh.reduce %[[ARG]]
487  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
488  // CHECK-SAME: root = [0, 1]
489  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
490  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
491    root = [0, 1]
492    : (tensor<2xi8>) -> tensor<2xi16>
493  return %0 : tensor<2xi16>
494}
495
// Round-trip of mesh.reduce_scatter with static shapes, an explicit
// `reduction = max`, scatter along tensor axis 1 over mesh axis 2 of @mesh0,
// and an f32 -> f64 element type change.
496// CHECK-LABEL: func @reduce_scatter_static_dimensions
497func.func @reduce_scatter_static_dimensions(
498    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
499    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
500  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
501  // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1
502  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
503  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
504    reduction = max scatter_axis = 1
505    : tensor<3x4xf32> -> tensor<3x1xf64>
506  return %0 : tensor<3x1xf64>
507}
508
// Round-trip of mesh.reduce_scatter with dynamic operand/result shapes over
// mesh axes [0, 1] of @mesh3. No reduction clause is written and none is
// expected in the output.
509// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
510func.func @reduce_scatter_dynamic_dimensions(
511    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
512    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
513  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
514  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
515  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
516  %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
517    : tensor<?xf32> -> tensor<?xf64>
518  return %0 : tensor<?xf64>
519}
520
// Round-trip of mesh.scatter with static shapes: tensor axis 1 is scattered
// over mesh axis 2 of @mesh0 from literal root [1] (3x4 -> 3x1).
521// CHECK-LABEL: func @scatter_static_dimensions
522func.func @scatter_static_dimensions(
523    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
524    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
525  // CHECK-NEXT: mesh.scatter %[[ARG]]
526  // CHECK-SAME: on @mesh0 mesh_axes = [2]
527  // CHECK-SAME: scatter_axis = 1 root = [1]
528  // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
529  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2]
530    scatter_axis = 1 root = [1]
531    : (tensor<3x4xf32>) -> tensor<3x1xf32>
532  return %0 : tensor<3x1xf32>
533}
534
// Round-trip of mesh.scatter with dynamic operand/result shapes over mesh
// axes [0, 1] of @mesh3 and a literal root ([1, 2]).
535// CHECK-LABEL: func @scatter_dynamic_dimensions
536func.func @scatter_dynamic_dimensions(
537    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
538    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
539  // CHECK-NEXT: mesh.scatter %[[ARG]]
540  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
541  // CHECK-SAME: scatter_axis = 0 root = [1, 2]
542  // CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
543  %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1]
544    scatter_axis = 0 root = [1, 2]
545    : (tensor<?xf32>) -> tensor<?xf32>
546  return %0 : tensor<?xf32>
547}
548
// Round-trip of mesh.scatter whose root mixes a literal with an index SSA
// value; the dynamic component shows up as an extra `index` operand type.
549// CHECK-LABEL: func @scatter_dynamic_root
550func.func @scatter_dynamic_root(
551    // CHECK-SAME: %[[ARG0:.*]]: tensor<8xi8>
552    %arg0 : tensor<8xi8>,
553    // CHECK-SAME: %[[ARG1:.*]]: index
554    %arg1 : index
555    ) -> tensor<1xi8> {
556  // CHECK-NEXT: mesh.scatter %[[ARG0]]
557  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
558  // CHECK-SAME: scatter_axis = 0
559  // CHECK-SAME: root = [1, %[[ARG1]]]
560  // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
561  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2]
562    scatter_axis = 0
563    root = [1, %arg1]
564    : (tensor<8xi8>, index) -> tensor<1xi8>
565  return %0 : tensor<1xi8>
566}
567
// Round-trip of mesh.send over mesh axes [0, 2] of @mesh0 with a destination
// given entirely as literals ([0, 1]).
568// CHECK-LABEL: func @send_static_destination
569func.func @send_static_destination(
570    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
571    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
572  // CHECK-NEXT: mesh.send %[[ARG]]
573  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
574  // CHECK-SAME: destination = [0, 1]
575  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
576  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
577    destination = [0, 1]
578    : (tensor<2xi8>) -> tensor<2xi8>
579  return %0 : tensor<2xi8>
580}
581
// Round-trip of mesh.send whose destination mixes a literal with an index SSA
// value; the dynamic component shows up as an extra `index` operand type.
582// CHECK-LABEL: func @send_dynamic_destination
583func.func @send_dynamic_destination(
584    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
585    %arg0 : tensor<2xi8>,
586    // CHECK-SAME: %[[ARG1:.*]]: index
587    %arg1 : index
588    ) -> tensor<2xi8> {
589  // CHECK-NEXT: mesh.send %[[ARG0]]
590  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
591  // CHECK-SAME: destination = [1, %[[ARG1]]]
592  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
593  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
594    destination = [1, %arg1]
595    : (tensor<2xi8>, index) -> tensor<2xi8>
596  return %0 : tensor<2xi8>
597}
598
// Round-trip of mesh.shift over mesh axes [0, 2] of @mesh0 with shift_axis = 2,
// a negative offset (-2) and the `rotate` keyword.
599// CHECK-LABEL: func @shift
600func.func @shift(
601    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
602    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
603  // CHECK-NEXT: mesh.shift %[[ARG]]
604  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
605  // CHECK-SAME: shift_axis = 2 offset = -2 rotate
606  // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
607  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2]
608    shift_axis = 2 offset = -2 rotate
609    : tensor<2xi8> -> tensor<2xi8>
610  return %0 : tensor<2xi8>
611}
612
// Round-trip of mesh.update_halo on a memref, chained twice: first with one
// split axis, then (on the first result) with two split axes. halo_sizes mixes
// literals with an i64 SSA value.
// The SSA constant is always referenced through the captured C2 variable so
// the test does not depend on the printer's auto-generated value name.
613// CHECK-LABEL: func @update_halo
614func.func @update_halo(
615    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
616    %arg0 : memref<12x12xi8>) {
617  // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
618  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
619  // CHECK-SAME: split_axes = {{\[\[}}0]]
620  // CHECK-SAME: halo_sizes = [2, %[[C2]]] : memref<12x12xi8>
621  %c2 = arith.constant 2 : i64
622  %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
623    halo_sizes = [2, %c2] : memref<12x12xi8>
624  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
625  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
626  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
627  %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
628    halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
629  return
630}
631