// Source: llvm-project/mlir/test/Dialect/Mesh/invalid.mlir (revision ffc7feadece139c88f0e6930f16bfa9293747adc)
// RUN: mlir-opt -split-input-file -verify-diagnostics %s

// A mesh declared with an empty shape list would have rank 0.
// expected-error@+1 {{rank of mesh is expected to be a positive integer}}
mesh.mesh @mesh0(shape = [])

// -----

// The shape must be a dimension list (e.g. `2x4`, or `[]` for empty); `-1` does not parse.
// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
mesh.mesh @mesh0(shape = -1)

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 0 is used to split two different tensor dimensions.
func.func @mesh_axis_duplicated_different_subarray(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis duplicated}}
  %s = mesh.sharding @mesh0 split_axes = [[0], [0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 0 appears twice inside the same split_axes sub-array.
func.func @mesh_axis_duplicated_same_subarray(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis duplicated}}
  %s = mesh.sharding @mesh0 split_axes = [[0, 0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 0 is used both as a split axis and as a partial axis.
func.func @mesh_axis_duplicated_bewteen_split_and_partial(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis duplicated}}
  %s = mesh.sharding @mesh0 split_axes = [[0]] partial=max[0] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// A negative mesh axis inside split_axes.
func.func @mesh_axis_negtive_in_split_part(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis is expected to be non-negative}}
  %s = mesh.sharding @mesh0 split_axes = [[-1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// A negative mesh axis inside the partial axes list.
func.func @mesh_axis_negtive_in_partial(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis is expected to be non-negative}}
  %s = mesh.sharding @mesh0 split_axes = [[0]] partial=max[-1] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

// A nested symbol reference (@a::@b) is not accepted as the mesh of mesh.sharding.
func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{custom op 'mesh.sharding' invalid kind of attribute specified}}
  %s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

// Declare the referenced mesh so this case does not also depend on an
// undefined mesh symbol.
mesh.mesh @mesh0(shape = 2x4)

// Two tensor dimensions are split, but halo_sizes has only two entries
// (apparently a pair is needed per split dimension; confirm in the verifier).
func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{halo sizes must be specified for all split axes}}
  %s = mesh.sharding @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

// Declare the referenced mesh so this case does not also depend on an
// undefined mesh symbol.
mesh.mesh @mesh0(shape = 2x4)

// halo_sizes and sharded_dims_offsets may not be given together.
func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{halo sizes and shard offsets are mutually exclusive}}
  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

mesh.mesh @mesh_dyn(shape = ?x?)

// sharded_dims_offsets cannot be used with a mesh whose shape is dynamic.
func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Two dimensions are split over mesh axes of sizes 2 and 4; the 7 offsets given
// do not match the required count (presumably size + 1 per split dimension).
func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{sharded dims offsets has wrong size}}
  %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

mesh.mesh @mesh0(shape = 4)

// The offsets of the split dimension decrease (3 followed by 2).
func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{sharded dims offsets must be non-decreasing}}
  %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 2 does not exist on a rank-2 mesh.
func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index
  return %0#0, %0#1 : index, index
}

// -----

mesh.mesh @mesh0(shape = 1x2x3)

// Mesh axis 0 is queried twice.
func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// One mesh axis is queried, but two results are declared.
func.func @mesh_shape_wrong_number_of_results() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
  %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index
  return %0#0, %0#1 : index, index
}

// -----

mesh.mesh @mesh0(shape = 1x2x3)

// Without `axes`, one result per mesh dimension (3 here) is required.
func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
  %0:2 = mesh.mesh_shape @mesh0 : index, index
  return %0#0, %0#1 : index, index
}

// -----

// The referenced mesh symbol is not defined.
func.func @mesh_shape_invalid_mesh_name() -> (index) {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index
  return %0 : index
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 2 does not exist on a rank-2 mesh.
func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
  return %0#0, %0#1 : index, index
}

// -----

mesh.mesh @mesh0(shape = 1x2x3)

// Mesh axis 0 is queried twice.
func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// One mesh axis is queried, but two results are declared.
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
  %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
  return %0#0, %0#1 : index, index
}

// -----

mesh.mesh @mesh0(shape = 1x2x3)

// Without `axes`, one result per mesh dimension (3 here) is required.
func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
  %0:2 = mesh.process_multi_index on @mesh0 : index, index
  return %0#0, %0#1 : index, index
}

// -----

// The referenced mesh symbol is not defined.
func.func @process_multi_index_invalid_mesh_name() -> (index) {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
  return %0 : index
}

// -----

// The referenced mesh symbol is not defined.
func.func @process_linear_index_invalid_mesh_name() -> (index) {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
  return %0 : index
}

// -----

// The referenced mesh symbol is not defined.
func.func @all_reduce_invalid_mesh_symbol(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 2 does not exist on a rank-2 mesh.
func.func @all_reduce_invalid_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 0 is listed twice.
func.func @all_reduce_duplicate_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Operand and result shapes differ (4 vs. 5).
func.func @all_reduce_invalid_tensor_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<5xf64> {
  // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
  %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
  return %0 : tensor<5xf64>
}

// -----

// The referenced mesh symbol is not defined.
func.func @all_gather_invalid_mesh_symbol(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 2 does not exist on a rank-2 mesh.
func.func @all_gather_invalid_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Duplicate mesh axes on mesh.all_gather (despite the all_reduce prefix in the
// symbol name). The duplicated axis is in bounds so that only the duplicate
// check can fail, independent of the order in which the verifier runs checks.
func.func @all_reduce_duplicate_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1, 1] gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

// Dimension 1 is not gathered, so its size (4) must be preserved.
func.func @all_gather_invalid_non_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
    : tensor<3x4xf32> -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.mesh @mesh0(shape = 1x2)

// Gathering dimension 1 (size 4) over mesh axis 1 (size 2) should give 8, not 5.
func.func @all_gather_invalid_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
    : tensor<3x4xf32> -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

// A dynamic gathered operand dimension cannot produce a static result size.
func.func @all_gather_invalid_gather_axis_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
  %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
    : tensor<?xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

// gather_axis 1 does not exist on a rank-1 tensor.
func.func @all_gather_invalid_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
    : tensor<3xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

// A negative gather_axis is rejected.
func.func @all_gather_invalid_negative_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
    : tensor<3xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

// Mesh axis 0 is listed twice.
func.func @all_slice_duplicate_mesh_axis(
    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0]
    slice_axis = 0
    : tensor<?xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

// A dynamic sliced operand dimension cannot produce a static result size.
func.func @all_slice_invalid_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
  %0 = mesh.all_slice %arg0 on @mesh0
    slice_axis = 0
    : tensor<?xf32> -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

// Slicing a size-3 dimension across 3 devices gives 1, not 2.
func.func @all_slice_invalid_static_dimension_size(
    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
    slice_axis = 0
    : tensor<3xf32> -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

// A size-4 dimension cannot be sliced evenly across 3 devices.
func.func @all_slice_invalid_operand_static_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
    slice_axis = 0
    : tensor<4xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// The referenced mesh symbol is not defined.
func.func @all_to_all_invalid_mesh_symbol(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
    split_axis = 1 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.mesh @mesh0(shape = 1)

// Mesh axis 0 is listed twice.
func.func @all_to_all_duplicate_mesh_axis(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
    split_axis = 0 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.mesh @mesh0(shape = ?x1)

// Mesh axis 0 has a dynamic size, so the concatenated result dimension must be dynamic.
func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.mesh @mesh0(shape = 1x1)

// The operand's split dimension is dynamic, so result dimension 0 must be dynamic.
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
    %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
    split_axis = 0 concat_axis = 1
    : tensor<?x6xi8> -> tensor<3x?xi8>
  return %0 : tensor<3x?xi8>
}

// -----

mesh.mesh @mesh0(shape = 1x1)

// The operand's concat dimension is dynamic, so result dimension 1 must be dynamic.
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
    %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
    split_axis = 0 concat_axis = 1
    : tensor<3x?xi8> -> tensor<?x3xi8>
  return %0 : tensor<?x3xi8>
}

// -----

mesh.mesh @mesh0(shape = 3)

// Concatenating a size-2 dimension over 3 devices gives 6, not 7.
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
    %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x2xi8> -> tensor<1x7xi8>
  return %0 : tensor<1x7xi8>
}

// -----

mesh.mesh @mesh0(shape = 3)

// Splitting a size-3 dimension over 3 devices gives 1, not 2.
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
    %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x2xi8> -> tensor<2x6xi8>
  return %0 : tensor<2x6xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Mesh axis 0 has size 3, so root coordinate 3 is out of range.
func.func @broadcast_root_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
    root = [3]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// root has two coordinates, but only one mesh axis is grouped.
func.func @broadcast_root_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
    root = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Input and result element types differ (i8 vs. i16).
func.func @broadcast_different_input_and_result_type(
    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
  // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
    root = [2]
    : (tensor<2xi8>) -> tensor<2xi16>
  return %0 : tensor<2xi16>
}

// -----

mesh.mesh @mesh0(shape = 1)

// Input and result element types differ (f32 vs. i8).
func.func @gather_wrong_return_element_type(
    %arg0 : tensor<1xf32>) -> tensor<1xi8> {
  // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
    : (tensor<1xf32>) -> tensor<1xi8>
  return %0 : tensor<1xi8>
}

// -----

mesh.mesh @mesh0(shape = 1)

// Dimension 1 is not gathered, so its size (4) must be preserved.
func.func @gather_invalid_non_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
    : (tensor<3x4xf32>) -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.mesh @mesh0(shape = 1x2)

// Gathering dimension 1 (size 4) over mesh axis 1 (size 2) should give 8, not 5.
func.func @gather_invalid_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
    : (tensor<3x4xf32>) -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

// A dynamic gathered operand dimension cannot produce a static result size.
func.func @gather_invalid_gather_axis_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
  %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
    : (tensor<?xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

// gather_axis 1 does not exist on a rank-1 tensor.
func.func @gather_invalid_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
    : (tensor<3xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

// A negative gather_axis is rejected.
func.func @gather_invalid_negative_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
    : (tensor<3xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Mesh axis 0 has size 3, so root coordinate 3 is out of range.
func.func @gather_root_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<6xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
    root = [3]
    : (tensor<2xi8>) -> tensor<6xi8>
  return %0 : tensor<6xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// root has two coordinates, but only one mesh axis is grouped.
func.func @gather_root_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
    root = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Mesh axis 0 has size 3, so source coordinate 3 is out of range.
func.func @receive_source_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
    source = [3]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// source has two coordinates, but only one mesh axis is grouped.
func.func @receive_source_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
    source = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Input and result element types differ (i8 vs. i16).
func.func @receive_different_input_and_result_type(
    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
  // expected-error@+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}}
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
    source = [2]
    : (tensor<2xi8>) -> tensor<2xi16>
  return %0 : tensor<2xi16>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Mesh axis 0 has size 3, so root coordinate 3 is out of range.
func.func @reduce_root_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
    root = [3]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// root has two coordinates, but only one mesh axis is grouped.
func.func @reduce_root_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
    root = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Input and result shapes differ (2 vs. 3).
func.func @reduce_different_input_and_result_shape(
    %arg0 : tensor<2xi8>) -> tensor<3xi16> {
  // expected-error@+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}}
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
    root = [2]
    : (tensor<2xi8>) -> tensor<3xi16>
  return %0 : tensor<3xi16>
}

// -----

mesh.mesh @mesh0(shape = 3)

// Mesh axis 0 is listed twice.
func.func @reduce_scatter_duplicate_mesh_axis(
    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
    : tensor<?xf32> -> tensor<?xf64>
  return %0 : tensor<?xf64>
}

// -----

mesh.mesh @mesh0(shape = 3)

// A dynamic scattered operand dimension cannot produce a static result size.
func.func @reduce_scatter_invalid_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<2xf64> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
    : tensor<?xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

// -----

mesh.mesh @mesh0(shape = 3)

// Scattering a size-3 dimension across 3 devices gives 1, not 2.
func.func @reduce_scatter_invalid_static_dimension_size(
    %arg0 : tensor<3xf32>) -> tensor<2xf64> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
    : tensor<3xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

// -----

mesh.mesh @mesh0(shape = 3)

// A size-4 dimension cannot be scattered evenly across 3 devices.
func.func @reduce_scatter_invalid_operand_static_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<?xf64> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
    : tensor<4xf32> -> tensor<?xf64>
  return %0 : tensor<?xf64>
}

// -----

mesh.mesh @mesh0(shape = 3)

// Mesh axis 0 is listed twice.
func.func @scatter_duplicate_mesh_axis(
    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0]
    scatter_axis = 0 root = [0, 0]
    : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

// A dynamic scattered operand dimension cannot produce a static result size.
func.func @scatter_invalid_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
  %0 = mesh.scatter %arg0 on @mesh0
    scatter_axis = 0 root = []
    : (tensor<?xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

// Scattering a size-3 dimension across 3 devices gives 1, not 2.
func.func @scatter_invalid_static_dimension_size(
    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
    scatter_axis = 0 root = [1]
    : (tensor<3xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

// A size-4 dimension cannot be scattered evenly across 3 devices.
func.func @scatter_invalid_operand_static_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
    scatter_axis = 0 root = [1]
    : (tensor<4xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Mesh axis 0 has size 3, so root coordinate 3 is out of range.
func.func @scatter_root_dimension_out_of_bounds(
    %arg0 : tensor<3xi8>) -> tensor<1xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
    scatter_axis = 0 root = [3]
    : (tensor<3xi8>) -> tensor<1xi8>
  return %0 : tensor<1xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// root has two coordinates, but only one mesh axis is grouped.
func.func @scatter_root_wrong_number_dimensions(
    %arg0 : tensor<3xi8>) -> tensor<1xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
    scatter_axis = 0 root = [2, 2]
    : (tensor<3xi8>) -> tensor<1xi8>
  return %0 : tensor<1xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Mesh axis 0 has size 3, so destination coordinate 3 is out of range.
func.func @send_destination_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
    destination = [3]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// destination has two coordinates, but only one mesh axis is grouped.
func.func @send_destination_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
    destination = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

// Input and result element types differ (i8 vs. i16).
func.func @send_different_input_and_result_type(
    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
  // expected-error@+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}}
  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
    destination = [2]
    : (tensor<2xi8>) -> tensor<2xi16>
  return %0 : tensor<2xi16>
}

// -----

// The referenced mesh symbol is not defined.
func.func @shift_invalid_mesh_symbol(
    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist
    shift_axis = 0 offset = -2
    : tensor<4xi8> -> tensor<4xi8>
  return %0 : tensor<4xi8>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 2 does not exist on a rank-2 mesh.
func.func @shift_invalid_mesh_axis(
    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2]
    shift_axis = 2 offset = -2
    : tensor<4xi8> -> tensor<4xi8>
  return %0 : tensor<4xi8>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Mesh axis 0 is listed twice.
func.func @shift_duplicate_mesh_axis(
    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0]
    shift_axis = 0 offset = -2
    : tensor<4xi8> -> tensor<4xi8>
  return %0 : tensor<4xi8>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// Operand and result shapes differ (4 vs. 5).
func.func @shift_invalid_tensor_dimension_size(
    %arg0 : tensor<4xi8>) -> tensor<5xi8> {
  // expected-error@+1 {{'mesh.shift' op requires the same shape for all operands and results}}
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
    shift_axis = 0 offset = 2
    : tensor<4xi8> -> tensor<5xi8>
  return %0 : tensor<5xi8>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

// shift_axis 1 is not among the grouped mesh_axes ([0]).
func.func @shift_invalid_shift_axis(
    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
  // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}}
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
    shift_axis = 1 offset = 2
    : tensor<4xi8> -> tensor<4xi8>
  return %0 : tensor<4xi8>
}
954