// RUN: mlir-opt -split-input-file -verify-diagnostics %s

// Negative tests for the mesh dialect. The input is processed chunk by chunk
// (see -split-input-file above); every diagnostic annotation uses the @+1
// form, so it must sit on the line directly above the operation it refers to.

// mesh.mesh declarations: empty shape list and malformed dimension list.

// expected-error@+1 {{rank of mesh is expected to be a positive integer}}
mesh.mesh @mesh0(shape = [])

// -----

// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
mesh.mesh @mesh0(shape = -1)

// -----

// mesh.sharding: mesh axis 0 is listed in two different split_axes sub-arrays.

mesh.mesh @mesh0(shape = 2x4)

func.func @mesh_axis_duplicated_different_subarray(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis duplicated}}
  %s = mesh.sharding @mesh0 split_axes = [[0], [0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

// mesh.sharding: mesh axis 0 is listed twice inside one split_axes sub-array.

mesh.mesh @mesh0(shape = 2x4)

func.func @mesh_axis_duplicated_same_subarray(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis duplicated}}
  %s = mesh.sharding @mesh0 split_axes = [[0, 0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

// mesh.sharding: mesh axis 0 is used both in split_axes and in partial.

mesh.mesh @mesh0(shape = 2x4)

func.func @mesh_axis_duplicated_between_split_and_partial(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis duplicated}}
  %s = mesh.sharding @mesh0 split_axes = [[0]] partial=max[0] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

// mesh.sharding: negative mesh axis inside split_axes.

mesh.mesh @mesh0(shape = 2x4)

func.func @mesh_axis_negative_in_split_part(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis is expected to be non-negative}}
  %s = mesh.sharding @mesh0 split_axes = [[-1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

// mesh.sharding: negative mesh axis inside partial.

mesh.mesh @mesh0(shape = 2x4)

func.func @mesh_axis_negative_in_partial(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // expected-error@+1 {{mesh axis is expected to be non-negative}}
  %s = mesh.sharding @mesh0 split_axes = [[0]] partial=max[-1] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// -----

// mesh.sharding: the mesh is referenced through a nested symbol (@a::@b).

func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{custom op 'mesh.sharding' invalid kind of attribute specified}}
  %s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

// mesh.sharding: two split axes but only one pair of halo sizes.

func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{halo sizes must be specified for all split axes}}
  %s = mesh.sharding @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

// mesh.sharding: halo_sizes and sharded_dims_offsets given together.

func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{halo sizes and shard offsets are mutually exclusive}}
  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

mesh.mesh @mesh_dyn(shape = ?x?)
// mesh.sharding: sharded_dims_offsets used with @mesh_dyn, whose shape is ?x?.
func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

// mesh.sharding: number of sharded_dims_offsets entries does not fit the mesh.

mesh.mesh @mesh0(shape = 2x4)
func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{sharded dims offsets has wrong size}}
  %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

// mesh.sharding: the last offset (2) is smaller than the one before it (3).

mesh.mesh @mesh0(shape = 4)
func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
  // expected-error@+1 {{sharded dims offsets must be non-decreasing}}
  %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return
}

// -----

// mesh.mesh_shape cases: axis out of bounds, duplicate axis, result-count
// mismatch (with and without explicit axes), and undefined mesh symbol.

mesh.mesh @mesh0(shape = 2x4)

func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index
  return %0#0, %0#1 : index, index
}

// -----

mesh.mesh @mesh0(shape = 1x2x3)

func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @mesh_shape_wrong_number_of_results() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
  %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index
  return %0#0, %0#1 : index, index
}

// -----

mesh.mesh @mesh0(shape = 1x2x3)

func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
  %0:2 = mesh.mesh_shape @mesh0 : index, index
  return %0#0, %0#1 : index, index
}

// -----

func.func @mesh_shape_invalid_mesh_name() -> (index) {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index
  return %0#0 : index
}

// -----

// mesh.process_multi_index cases, mirroring the mesh.mesh_shape ones above.

mesh.mesh @mesh0(shape = 2x4)

func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
  return %0#0, %0#1 : index, index
}

// -----

mesh.mesh @mesh0(shape = 1x2x3)

func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
  %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
  return %0#0, %0#1 : index, index
}

// -----

mesh.mesh @mesh0(shape = 1x2x3)

func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
  %0:2 = mesh.process_multi_index on @mesh0 : index, index
  return %0#0, %0#1 : index, index
}

// -----

func.func @process_multi_index_invalid_mesh_name() -> (index) {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
  return %0 : index
}

// -----

// mesh.process_linear_index: undefined mesh symbol.

func.func @process_linear_index_invalid_mesh_name() -> (index) {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
  return %0 : index
}

// -----

// mesh.all_reduce cases: undefined mesh symbol, axis out of bounds,
// duplicate axis, and operand/result shape mismatch.

func.func @all_reduce_invalid_mesh_symbol(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @all_reduce_invalid_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @all_reduce_duplicate_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @all_reduce_invalid_tensor_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<5xf64> {
  // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
  %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
  return %0 : tensor<5xf64>
}

// -----

// mesh.all_gather cases: undefined mesh symbol, axis out of bounds, duplicate
// axis, result dimension mismatches, and gather_axis out of bounds.

func.func @all_gather_invalid_mesh_symbol(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @all_gather_invalid_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @all_gather_duplicate_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @all_gather_invalid_non_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
    : tensor<3x4xf32> -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.mesh @mesh0(shape = 1x2)

func.func @all_gather_invalid_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
    : tensor<3x4xf32> -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @all_gather_invalid_gather_axis_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
  %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
    : tensor<?xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @all_gather_invalid_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
    : tensor<3xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @all_gather_invalid_negative_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
    : tensor<3xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

// mesh.all_slice cases: duplicate axis, result dimension mismatches, and an
// operand dimension that is not divisible by the device group size.

mesh.mesh @mesh0(shape = 3)

func.func @all_slice_duplicate_mesh_axis(
    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0]
    slice_axis = 0
    : tensor<?xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @all_slice_invalid_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
  %0 = mesh.all_slice %arg0 on @mesh0
    slice_axis = 0
    : tensor<?xf32> -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @all_slice_invalid_static_dimension_size(
    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
    slice_axis = 0
    : tensor<3xf32> -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @all_slice_invalid_operand_static_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
    slice_axis = 0
    : tensor<4xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// mesh.all_to_all cases: undefined mesh symbol, duplicate axis, and result
// dimension mismatches on the split and concat axes.

func.func @all_to_all_invalid_mesh_symbol(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
    split_axis = 1 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @all_to_all_duplicate_mesh_axis(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
    split_axis = 0 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.mesh @mesh0(shape = ?x1)

func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.mesh @mesh0(shape = 1x1)

func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
    %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
    split_axis = 0 concat_axis = 1
    : tensor<?x6xi8> -> tensor<3x?xi8>
  return %0 : tensor<3x?xi8>
}

// -----

mesh.mesh @mesh0(shape = 1x1)

func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
    %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
    split_axis = 0 concat_axis = 1
    : tensor<3x?xi8> -> tensor<?x3xi8>
  return %0 : tensor<?x3xi8>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
    %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x2xi8> -> tensor<1x7xi8>
  return %0 : tensor<1x7xi8>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
    %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x2xi8> -> tensor<2x6xi8>
  return %0 : tensor<2x6xi8>
}

// -----

// mesh.broadcast cases: root coordinate out of range and root multi-index of
// the wrong size.

mesh.mesh @mesh0(shape = 3x?)

func.func @broadcast_root_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
    root = [3]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

func.func @broadcast_root_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
    root = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)
// mesh.broadcast: result element type (i16) differs from the input's (i8).
func.func @broadcast_different_input_and_result_type(
    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
  // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
    root = [2]
    : (tensor<2xi8>) -> tensor<2xi16>
  return %0 : tensor<2xi16>
}

// -----

// mesh.gather cases: element type mismatch, result dimension mismatches,
// gather_axis out of bounds, and invalid root multi-index.

mesh.mesh @mesh0(shape = 1)

func.func @gather_wrong_return_element_type(
    %arg0 : tensor<1xf32>) -> tensor<1xi8> {
  // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
    : (tensor<1xf32>) -> tensor<1xi8>
  return %0 : tensor<1xi8>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @gather_invalid_non_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
    : (tensor<3x4xf32>) -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.mesh @mesh0(shape = 1x2)

func.func @gather_invalid_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
    : (tensor<3x4xf32>) -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @gather_invalid_gather_axis_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
  %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
    : (tensor<?xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @gather_invalid_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
    : (tensor<3xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 1)

func.func @gather_invalid_negative_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
    : (tensor<3xf32>) -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

func.func @gather_root_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<6xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
    root = [3]
    : (tensor<2xi8>) -> tensor<6xi8>
  return %0 : tensor<6xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

func.func @gather_root_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
    root = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

// mesh.recv cases: source coordinate out of range, source multi-index of the
// wrong size, and element type mismatch.

mesh.mesh @mesh0(shape = 3x?)

func.func @receive_source_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
    source = [3]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

func.func @receive_source_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
    source = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

func.func @receive_different_input_and_result_type(
    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
  // expected-error@+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}}
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
    source = [2]
    : (tensor<2xi8>) -> tensor<2xi16>
  return %0 : tensor<2xi16>
}

// -----

// mesh.reduce cases: root coordinate out of range and root multi-index of
// the wrong size.

mesh.mesh @mesh0(shape = 3x?)

func.func @reduce_root_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
    root = [3]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

func.func @reduce_root_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
    root = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)
// mesh.reduce: the result shape (3) differs from the input shape (2).
func.func @reduce_different_input_and_result_shape(
    %arg0 : tensor<2xi8>) -> tensor<3xi16> {
  // expected-error@+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}}
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
    root = [2]
    : (tensor<2xi8>) -> tensor<3xi16>
  return %0 : tensor<3xi16>
}

// -----

// mesh.reduce_scatter cases: duplicate axis, result dimension mismatches, and
// an operand dimension that is not divisible by the device group size.

mesh.mesh @mesh0(shape = 3)

func.func @reduce_scatter_duplicate_mesh_axis(
    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
    : tensor<?xf32> -> tensor<?xf64>
  return %0 : tensor<?xf64>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @reduce_scatter_invalid_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<2xf64> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
    : tensor<?xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @reduce_scatter_invalid_static_dimension_size(
    %arg0 : tensor<3xf32>) -> tensor<2xf64> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
    : tensor<3xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @reduce_scatter_invalid_operand_static_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<?xf64> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
    : tensor<4xf32> -> tensor<?xf64>
  return %0 : tensor<?xf64>
}

// -----

// mesh.scatter cases: duplicate axis, result dimension mismatches, operand
// dimension not divisible by the group size, and invalid root multi-index.

mesh.mesh @mesh0(shape = 3)

func.func @scatter_duplicate_mesh_axis(
    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0]
    scatter_axis = 0 root = [0, 0]
    : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @scatter_invalid_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
  %0 = mesh.scatter %arg0 on @mesh0
    scatter_axis = 0 root = []
    : (tensor<?xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @scatter_invalid_static_dimension_size(
    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
    scatter_axis = 0 root = [1]
    : (tensor<3xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// -----

mesh.mesh @mesh0(shape = 3)

func.func @scatter_invalid_operand_static_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
    scatter_axis = 0 root = [1]
    : (tensor<4xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

func.func @scatter_root_dimension_out_of_bounds(
    %arg0 : tensor<3xi8>) -> tensor<1xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
    scatter_axis = 0 root = [3]
    : (tensor<3xi8>) -> tensor<1xi8>
  return %0 : tensor<1xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)

func.func @scatter_root_wrong_number_dimensions(
    %arg0 : tensor<3xi8>) -> tensor<1xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
    scatter_axis = 0 root = [2, 2]
    : (tensor<3xi8>) -> tensor<1xi8>
  return %0 : tensor<1xi8>
}

// -----

// mesh.send: destination coordinate 3 with mesh axis 0 of size 3.

mesh.mesh @mesh0(shape = 3x?)

func.func @send_destination_dimension_out_of_bounds(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}}
  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
    destination = [3]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

mesh.mesh @mesh0(shape = 3x?)
// mesh.send: two destination coordinates given for a single mesh axis.
func.func @send_destination_wrong_number_dimensions(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}}
  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
    destination = [2, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// -----

// mesh.send: result element type (i16) differs from the input's (i8).

mesh.mesh @mesh0(shape = 3x?)

func.func @send_different_input_and_result_type(
    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
  // expected-error@+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}}
  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
    destination = [2]
    : (tensor<2xi8>) -> tensor<2xi16>
  return %0 : tensor<2xi16>
}

// -----

// mesh.shift cases: undefined mesh symbol, axis out of bounds, duplicate
// axis, operand/result shape mismatch, and shift_axis outside mesh_axes.

func.func @shift_invalid_mesh_symbol(
    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist
    shift_axis = 0 offset = -2
    : tensor<4xi8> -> tensor<4xi8>
  return %0 : tensor<4xi8>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @shift_invalid_mesh_axis(
    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2]
    shift_axis = 2 offset = -2
    : tensor<4xi8> -> tensor<4xi8>
  return %0 : tensor<4xi8>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @shift_duplicate_mesh_axis(
    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0]
    shift_axis = 0 offset = -2
    : tensor<4xi8> -> tensor<4xi8>
  return %0 : tensor<4xi8>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @shift_invalid_tensor_dimension_size(
    %arg0 : tensor<4xi8>) -> tensor<5xi8> {
  // expected-error@+1 {{'mesh.shift' op requires the same shape for all operands and results}}
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
    shift_axis = 0 offset = 2
    : tensor<4xi8> -> tensor<5xi8>
  return %0 : tensor<5xi8>
}

// -----

mesh.mesh @mesh0(shape = 2x4)

func.func @shift_invalid_shift_axis(
    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
  // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}}
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
    shift_axis = 1 offset = 2
    : tensor<4xi8> -> tensor<4xi8>
  return %0 : tensor<4xi8>
}