// RUN: mlir-opt %s | mlir-opt | FileCheck %s

// Round-trip test for the mesh dialect: the file is parsed and printed by
// mlir-opt, the printed form is parsed and printed again, and the final
// output is matched against the FileCheck directives below.

//===----------------------------------------------------------------------===//
// Mesh symbols referenced by the functions in this file.
//===----------------------------------------------------------------------===//

// CHECK: mesh.mesh @mesh0
mesh.mesh @mesh0(shape = 2x2x4)

// CHECK: mesh.mesh @mesh1(shape = 4x?)
mesh.mesh @mesh1(shape = 4x?)

// CHECK: mesh.mesh @mesh2(shape = ?x4)
mesh.mesh @mesh2(shape = ?x4)

// CHECK: mesh.mesh @mesh3(shape = ?x?)
mesh.mesh @mesh3(shape = ?x?)

// CHECK: mesh.mesh @mesh4(shape = 3)
mesh.mesh @mesh4(shape = 3)

// CHECK: mesh.mesh @mesh5(shape = ?)
mesh.mesh @mesh5(shape = ?)

//===----------------------------------------------------------------------===//
// mesh.sharding / mesh.shard
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @mesh_shard_op_fully_replicated
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
  %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: func @mesh_shard_op_1st_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
  %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: func @mesh_shard_op_2nd_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
  %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
func.func @mesh_shard_op_1st_and_3rd_dim(
    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
    %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding
  %s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32>
  return %0 : tensor<4x8x16xf32>
}

// The `partial = <kind>[...]` inputs below are written without a space before
// the axis list; the expected output has one (`<kind> [...]`).

// CHECK-LABEL: func @mesh_shard_op_partial_max
func.func @mesh_shard_op_partial_max(
    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = max [1] : !mesh.sharding
  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = max[1] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: func @mesh_shard_op_partial_min
func.func @mesh_shard_op_partial_min(
    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = min [1] : !mesh.sharding
  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = min[1] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: func @mesh_shard_op_partial_generic
func.func @mesh_shard_op_partial_generic(
    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = generic [1] : !mesh.sharding
  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = generic[1] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: func @mesh_shard_op_partial_sum
func.func @mesh_shard_op_partial_sum(
    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1] : !mesh.sharding
  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: func @mesh_shard_op_partial_sum_multi_axes
func.func @mesh_shard_op_partial_sum_multi_axes(
    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1, 2] : !mesh.sharding
  %s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1, 2] : !mesh.sharding
  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// One sharded value feeding two `annotate_for_users` shard ops. The second and
// third sharding ops are matched order-independently.
// CHECK-LABEL: func @mesh_shard_op_two_users
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
                                  (tensor<4x8xf32>, tensor<4x8xf32>) {
  // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
  %s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32>
  // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}1]] : !mesh.sharding
  %s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
  // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding
  %s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
  %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
  return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}

// Halo sizes given as static integers and as a mix of static and SSA values.
// CHECK-LABEL: func @mesh_shard_halo_sizes
func.func @mesh_shard_halo_sizes() -> () {
  // CHECK: %[[C3:.*]] = arith.constant 3 : i64
  %c3 = arith.constant 3 : i64
  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding
  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding
  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding
  return
}

// Sharded-dimension offsets given as static integers and as a mix of static
// and SSA values.
// CHECK-LABEL: func @mesh_shard_dims_sizes
func.func @mesh_shard_dims_sizes() -> () {
  // CHECK: %[[C3:.*]] = arith.constant 3 : i64
  %c3 = arith.constant 3 : i64
  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding
  return
}

//===----------------------------------------------------------------------===//
// mesh.shard_shape / mesh.mesh_shape / process index queries
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @mesh_shard_shape
func.func @mesh_shard_shape() {
  // CHECK: %[[C3:.*]] = arith.constant 3 : index
  %c3 = arith.constant 3 : index
  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
  %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
  // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
  %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
  // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
  %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
  return
}

// CHECK-LABEL: func @mesh_shape
func.func @mesh_shape() -> (index, index) {
  // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
  %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
  return %0#0, %0#1 : index, index
}

// CHECK-LABEL: func @mesh_shape_default_axes
func.func @mesh_shape_default_axes() -> (index, index, index) {
  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
  %0:3 = mesh.mesh_shape @mesh0 : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// An explicit empty `axes = []` is expected to be elided in the printed form.
// CHECK-LABEL: func @mesh_shape_empty_axes
func.func @mesh_shape_empty_axes() -> (index, index, index) {
  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
  %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// CHECK-LABEL: func @process_multi_index
func.func @process_multi_index() -> (index, index) {
  // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
  return %0#0, %0#1 : index, index
}

// CHECK-LABEL: func @process_multi_index_default_axes
func.func @process_multi_index_default_axes() -> (index, index, index) {
  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// An explicit empty `axes = []` is expected to be elided in the printed form.
// CHECK-LABEL: func @process_multi_index_empty_axes
func.func @process_multi_index_empty_axes() -> (index, index, index) {
  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}

// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
  // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
  %0 = mesh.process_linear_index on @mesh0 : index
  // CHECK: return %[[RES]] : index
  return %0 : index
}

//===----------------------------------------------------------------------===//
// Collective communication ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @all_reduce
func.func @all_reduce(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
    %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
  // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max
  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max
    : tensor<3x4xf32> -> tensor<3x4xf64>
  return %0 : tensor<3x4xf64>
}

// CHECK-LABEL: func @all_gather
func.func @all_gather(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
    %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
    : tensor<3x4xf32> -> tensor<3x16xf32>
  return %0 : tensor<3x16xf32>
}

// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
func.func @all_gather_dynamic_dims_in_tensor(
    // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
    %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
  // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
    : tensor<?x?xf32> -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
func.func @all_gather_dynamic_dims_in_mesh(
    // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
    %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
  // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
  %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
    : tensor<5x6xf32> -> tensor<5x?xf32>
  return %0 : tensor<5x?xf32>
}

// CHECK-LABEL: func @all_slice_static_dimensions
func.func @all_slice_static_dimensions(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
  // CHECK-NEXT: mesh.all_slice %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1
    : tensor<3x4xf32> -> tensor<3x1xf32>
  return %0 : tensor<3x1xf32>
}

// CHECK-LABEL: func @all_slice_dynamic_dimensions
func.func @all_slice_dynamic_dimensions(
    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-NEXT: mesh.all_slice %[[ARG]]
  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
  %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0
    : tensor<?xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @all_to_all
func.func @all_to_all(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
  %0 = mesh.all_to_all %arg0 on @mesh4
    split_axis = 1 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
func.func @all_to_all_dynamic_dims_in_result(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
    %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
  %0 = mesh.all_to_all %arg0 on @mesh4
    split_axis = 1 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x?xi8>
  return %0 : tensor<3x?xi8>
}

// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
    %arg0 : tensor<3xi8>) -> tensor<3xi8> {
  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
  // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
  // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
  %0 = mesh.all_to_all %arg0 on @mesh4
    split_axis = 0 concat_axis = 0
    : tensor<3xi8> -> tensor<3xi8>
  return %0 : tensor<3xi8>
}

// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
func.func @all_to_all_non_divisible_split_axis_size(
    // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
    %arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
  // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
  // CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
    split_axis = 0 concat_axis = 1
    : tensor<2x3xi8> -> tensor<?x12xi8>
  return %0 : tensor<?x12xi8>
}

// CHECK-LABEL: func @broadcast_static_root
func.func @broadcast_static_root(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // CHECK-NEXT: mesh.broadcast %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [0, 1]
  // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
    root = [0, 1]
    : (tensor<3x6xi8>) -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// CHECK-LABEL: func @broadcast_dynamic_root
func.func @broadcast_dynamic_root(
    // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
    %arg0 : tensor<3x6xi8>,
    // CHECK-SAME: %[[ARG1:.*]]: index
    %arg1 : index
  ) -> tensor<3x6xi8> {
  // CHECK-NEXT: mesh.broadcast %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
    root = [1, %arg1]
    : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// CHECK-LABEL: func @gather_static_root
func.func @gather_static_root(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
    %arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> {
  // CHECK-NEXT: mesh.gather %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: gather_axis = 0
  // CHECK-SAME: root = [0, 1]
  // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
    gather_axis = 0
    root = [0, 1]
    : (tensor<3x6xi8>) -> tensor<24x6xi8>
  return %0 : tensor<24x6xi8>
}

// CHECK-LABEL: func @gather_dynamic_root
func.func @gather_dynamic_root(
    // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
    %arg0 : tensor<3x6xi8>,
    // CHECK-SAME: %[[ARG1:.*]]: index
    %arg1 : index
  ) -> tensor<24x6xi8> {
  // CHECK-NEXT: mesh.gather %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: gather_axis = 0
  // CHECK-SAME: root = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
    gather_axis = 0
    root = [1, %arg1]
    : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
  return %0 : tensor<24x6xi8>
}

// CHECK-LABEL: func @receive_static_source
func.func @receive_static_source(
    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.recv %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: source = [0, 1]
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
    source = [0, 1]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// CHECK-LABEL: func @receive_dynamic_source
func.func @receive_dynamic_source(
    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>,
    // CHECK-SAME: %[[ARG1:.*]]: index
    %arg1 : index
  ) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.recv %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: source = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
    source = [1, %arg1]
    : (tensor<2xi8>, index) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// With no `source` in the input, none may appear in the output.
// CHECK-LABEL: func @receive_no_source
func.func @receive_no_source(
    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.recv %[[ARG]]
  // CHECK-NOT: source
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// CHECK-LABEL: func @reduce_static_root
func.func @reduce_static_root(
    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.reduce %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [0, 1]
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
    root = [0, 1]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// CHECK-LABEL: func @reduce_dynamic_root
func.func @reduce_dynamic_root(
    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>,
    // CHECK-SAME: %[[ARG1:.*]]: index
    %arg1 : index
  ) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.reduce %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
    root = [1, %arg1]
    : (tensor<2xi8>, index) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// CHECK-LABEL: func @reduce_different_return_element_type
func.func @reduce_different_return_element_type(
    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
  // CHECK-NEXT: mesh.reduce %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [0, 1]
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
    root = [0, 1]
    : (tensor<2xi8>) -> tensor<2xi16>
  return %0 : tensor<2xi16>
}

// CHECK-LABEL: func @reduce_scatter_static_dimensions
func.func @reduce_scatter_static_dimensions(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1
  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
    reduction = max scatter_axis = 1
    : tensor<3x4xf32> -> tensor<3x1xf64>
  return %0 : tensor<3x1xf64>
}

// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
func.func @reduce_scatter_dynamic_dimensions(
    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
  %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
    : tensor<?xf32> -> tensor<?xf64>
  return %0 : tensor<?xf64>
}

// CHECK-LABEL: func @scatter_static_dimensions
func.func @scatter_static_dimensions(
    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
  // CHECK-NEXT: mesh.scatter %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [2]
  // CHECK-SAME: scatter_axis = 1 root = [1]
  // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2]
    scatter_axis = 1 root = [1]
    : (tensor<3x4xf32>) -> tensor<3x1xf32>
  return %0 : tensor<3x1xf32>
}

// CHECK-LABEL: func @scatter_dynamic_dimensions
func.func @scatter_dynamic_dimensions(
    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-NEXT: mesh.scatter %[[ARG]]
  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
  // CHECK-SAME: scatter_axis = 0 root = [1, 2]
  // CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
  %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1]
    scatter_axis = 0 root = [1, 2]
    : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @scatter_dynamic_root
func.func @scatter_dynamic_root(
    // CHECK-SAME: %[[ARG0:.*]]: tensor<8xi8>
    %arg0 : tensor<8xi8>,
    // CHECK-SAME: %[[ARG1:.*]]: index
    %arg1 : index
  ) -> tensor<1xi8> {
  // CHECK-NEXT: mesh.scatter %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: scatter_axis = 0
  // CHECK-SAME: root = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2]
    scatter_axis = 0
    root = [1, %arg1]
    : (tensor<8xi8>, index) -> tensor<1xi8>
  return %0 : tensor<1xi8>
}

// CHECK-LABEL: func @send_static_destination
func.func @send_static_destination(
    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.send %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: destination = [0, 1]
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
    destination = [0, 1]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// CHECK-LABEL: func @send_dynamic_destination
func.func @send_dynamic_destination(
    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>,
    // CHECK-SAME: %[[ARG1:.*]]: index
    %arg1 : index
  ) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.send %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: destination = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
    destination = [1, %arg1]
    : (tensor<2xi8>, index) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

// CHECK-LABEL: func @shift
func.func @shift(
    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.shift %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: shift_axis = 2 offset = -2 rotate
  // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2]
    shift_axis = 2 offset = -2 rotate
    : tensor<2xi8> -> tensor<2xi8>
  return %0 : tensor<2xi8>
}

//===----------------------------------------------------------------------===//
// mesh.update_halo
//===----------------------------------------------------------------------===//

// The dynamic halo size is matched through the captured constant rather than
// a printer-generated SSA name, so the test does not depend on how the
// printer names values.
// CHECK-LABEL: func @update_halo
func.func @update_halo(
    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
    %arg0 : memref<12x12xi8>) {
  // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
  // CHECK-SAME: split_axes = {{\[\[}}0]]
  // CHECK-SAME: halo_sizes = [2, %[[C2]]] : memref<12x12xi8>
  %c2 = arith.constant 2 : i64
  %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
    halo_sizes = [2, %c2] : memref<12x12xi8>
  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
  %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
    halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
  return
}