// RUN: mlir-opt --canonicalize %s | FileCheck %s

// Canonicalization tests for mesh-dialect collective ops and mesh.sharding.

mesh.mesh @mesh0(shape = 2x4)

// A collective over an empty mesh_axes list with matching operand/result
// types is expected to fold away entirely, forwarding its operand.

// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_reduce
  %0 = mesh.all_reduce %arg0 on @mesh0
    mesh_axes = []
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// With a different result element type the op must stay, but the empty
// mesh_axes list is expected not to appear in the printed form.
// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
func.func @all_reduce_empty_mesh_axes_different_return_type(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
// CHECK: mesh.all_reduce
  %0 = mesh.all_reduce %arg0 on @mesh0
// CHECK-NOT: mesh_axes
    mesh_axes = []
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// An explicit `reduction = sum` is expected to be elided from the output.
// CHECK-LABEL: func @all_reduce_default_reduction
func.func @all_reduce_default_reduction(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  %0 = mesh.all_reduce %arg0 on @mesh0
    mesh_axes = [0]
// CHECK-NOT: reduction
    reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// CHECK-LABEL: func @all_to_all_empty_mesh_axes
func.func @all_to_all_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
    %arg0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-NOT: mesh.all_to_all
  %0 = mesh.all_to_all %arg0 on @mesh0
    mesh_axes = []
    split_axis = 0
    concat_axis = 0
    : tensor<8xf32> -> tensor<8xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<8xf32>
}

// CHECK-LABEL: func @all_gather_empty_mesh_axes
func.func @all_gather_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_gather
  %0 = mesh.all_gather %arg0 on @mesh0
    mesh_axes = []
    gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @all_slice_empty_mesh_axes
func.func @all_slice_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_slice
  %0 = mesh.all_slice %arg0 on @mesh0
    mesh_axes = []
    slice_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @broadcast_empty_mesh_axes
func.func @broadcast_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.broadcast
  %0 = mesh.broadcast %arg0 on @mesh0
    mesh_axes = []
    root = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @gather_empty_mesh_axes
func.func @gather_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.gather
  %0 = mesh.gather %arg0 on @mesh0
    mesh_axes = []
    gather_axis = 0
    root = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @receive_empty_mesh_axes
func.func @receive_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.recv
  %0 = mesh.recv %arg0 on @mesh0
    mesh_axes = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @reduce_empty_mesh_axes
func.func @reduce_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.reduce
  %0 = mesh.reduce %arg0 on @mesh0
    mesh_axes = []
    root = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
func.func @reduce_scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.reduce_scatter
  %0 = mesh.reduce_scatter %arg0 on @mesh0
    mesh_axes = []
    scatter_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
func.func @reduce_scatter_empty_mesh_axes_different_return_type(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
// CHECK: mesh.reduce_scatter
  %0 = mesh.reduce_scatter %arg0 on @mesh0
// CHECK-NOT: mesh_axes
    mesh_axes = []
    scatter_axis = 0
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// CHECK-LABEL: func @reduce_scatter_default_reduction
func.func @reduce_scatter_default_reduction(
    %arg0 : tensor<4xf32>) -> tensor<2xf64> {
  %0 = mesh.reduce_scatter %arg0 on @mesh0
    mesh_axes = [0]
// CHECK-NOT: reduction
    reduction = sum
    scatter_axis = 0
    : tensor<4xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

// CHECK-LABEL: func @scatter_empty_mesh_axes
func.func @scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.scatter
  %0 = mesh.scatter %arg0 on @mesh0
    mesh_axes = []
    scatter_axis = 0
    root = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @send_empty_mesh_axes
func.func @send_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.send
  %0 = mesh.send %arg0 on @mesh0
    mesh_axes = []
    destination = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

mesh.mesh @mesh4x4(shape = 4x4)

// Constant SSA operands in halo_sizes / sharded_dims_offsets are expected to
// be folded into the static list. The literal `[[` is written as a regex
// block below because FileCheck would otherwise parse it as the start of a
// variable substitution.

// CHECK-LABEL: func @test_halo_sizes
func.func @test_halo_sizes() -> !mesh.sharding {
  %c2_i64 = arith.constant 2 : i64
  // CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
  return %sharding : !mesh.sharding
}

// CHECK-LABEL: func @test_shard_offs
func.func @test_shard_offs() -> !mesh.sharding {
  %c2_i64 = arith.constant 2 : i64
  // CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
  return %sharding : !mesh.sharding
}