// RUN: mlir-opt \
// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN:   %s | FileCheck %s

// Each function below is written against the full (unsharded) tensor shapes
// and annotated with mesh.sharding / mesh.shard ops. The expected output uses
// the smaller shard shapes and, where the producer-side and user-side
// shardings differ, an explicit mesh communication op.

// One-dimensional mesh of shape 2.
mesh.mesh @mesh_1d(shape = 2)

// Both annotations use empty split_axes (full replication), so the argument
// is expected to be returned as is.
// CHECK-LABEL: func @full_replication
func.func @full_replication(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
  %arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<2xi8> {
) -> tensor<2xi8> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
  // CHECK: return %[[ARG]] : tensor<2xi8>
  return %1 : tensor<2xi8>
}

// The value is split along tensor axis 0, consumed with the same sharding and
// then annotated once more with empty split_axes. The 1-element shard is
// expected to be all-gathered back into the 2-element tensor.
// CHECK-LABEL: func @sharding_triplet
func.func @sharding_triplet(
  // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
  %arg0: tensor<2xf32>
// CHECK-SAME: ) -> tensor<2xf32> {
) -> tensor<2xf32> {
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
  %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
  %ssharding_annotated_0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0 annotate_for_users : tensor<2xf32>
  %ssharding_annotated_1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 : tensor<2xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
  return %sharding_annotated_1 : tensor<2xf32>
}


// The split moves from tensor axis 0 to tensor axis 1 on the same mesh axis.
// A single all_to_all is expected, turning the 1x2 shard into a 2x1 shard.
// CHECK-LABEL: func @move_split_axis
func.func @move_split_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
  %arg0: tensor<2x2xi8>
// CHECK-SAME: -> tensor<2x1xi8> {
) -> tensor<2x2xi8> {
  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
  // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2x2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2x2xi8>
  // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
  return %1 : tensor<2x2xi8>
}

// A scalar function without any sharding annotation; the signature and the
// arith.addi are expected to come out unchanged.
// CHECK-LABEL: func @non_tensor_value
func.func @non_tensor_value(
  // CHECK-SAME: %[[ARG:.*]]: i8
  %arg0: i8
// CHECK-SAME: -> i8 {
) -> i8 {
  // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
  %0 = arith.addi %arg0, %arg0 : i8
  // CHECK: return %[[RES]] : i8
  return %0 : i8
}

// Operand and result are both split along tensor axis 0; tosa.abs is expected
// to be applied directly to the 1-element shard argument.
// CHECK-LABEL: func @unary_elementwise
func.func @unary_elementwise(
  // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
  %arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
  // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %3 = mesh.shard %2 to %s3 : tensor<2xi8>
  %s4 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RES]] : tensor<1xi8>
  return %4 : tensor<2xi8>
}

// full replication -> shard axis -> abs -> shard axis -> full replication
// CHECK-LABEL: func @unary_elementwise_with_resharding
func.func @unary_elementwise_with_resharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
  %arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<2xi8> {
) -> tensor<2xi8> {
  // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
  // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
  // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
  // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %3 = mesh.shard %2 to %s3 : tensor<2xi8>
  %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RES]] : tensor<2xi8>
  return %4 : tensor<2xi8>
}

// Both operands and the result are split along tensor axis 0; tosa.add is
// expected to be applied directly to the two 1-element shard arguments.
// CHECK-LABEL: func @binary_elementwise
func.func @binary_elementwise(
  // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
  %arg0: tensor<2xi8>,
  // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
  %arg1: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
  %sarg0_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2xi8>
  %sop_arg0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8>
  %sarg1_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded : tensor<2xi8>
  %sop_arg1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8>
  // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
  %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
  %sop_res_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %op_res_sharded = mesh.shard %op_res to %sop_res_sharded : tensor<2xi8>
  %sres = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %res = mesh.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RES]] : tensor<1xi8>
  return %res : tensor<2xi8>
}

// reshard
// abs
// reshard
// abs
// reshard
// CHECK-LABEL: func @multiple_chained_ops
func.func @multiple_chained_ops(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
  %arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
  // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
  // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
  // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
  // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %3 = mesh.shard %2 to %s3 : tensor<2xi8>
  %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
  // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
  %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
  // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
  %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %6 = mesh.shard %5 to %s6 : tensor<2xi8>
  %s7 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %7 = mesh.shard %6 to %s7 annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RESHARD3]] : tensor<1xi8>
  return %7 : tensor<2xi8>
}

// The argument carries only a user-side annotation and the result only a
// producer-side annotation (neither has its matching counterpart). The
// sigmoid is still expected to run on the 4x16 shard.
// CHECK-LABEL: func @incomplete_sharding
func.func @incomplete_sharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
  %arg0: tensor<8x16xf32>
// CHECK-SAME: -> tensor<4x16xf32> {
) -> tensor<8x16xf32> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
  // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %s2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %2 = mesh.shard %1 to %s2 : tensor<8x16xf32>
  // CHECK: return %[[RES]] : tensor<4x16xf32>
  return %2 : tensor<8x16xf32>
}

// One-dimensional mesh of shape 4.
mesh.mesh @mesh_1d_4(shape = 4)

// Chain of elementwise ops, all split along tensor axis 0 with
// halo_sizes = [2, 1]. The expected shard has 8 / 4 + 2 + 1 = 5 rows, and the
// three ops plus the return are expected on consecutive output lines.
// CHECK-LABEL: func @ew_chain_with_halo
func.func @ew_chain_with_halo(
  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
  %arg0: tensor<8x16xf32>)
  // CHECK-SAME: -> tensor<5x16xf32>
   -> tensor<8x16xf32> {
  %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated annotate_for_users : tensor<8x16xf32>
  // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
  %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0 : tensor<8x16xf32>
  %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
  %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32>
  %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
  %2 = tosa.negate %sharding_annotated_4 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32>
  %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6 annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
  return %sharding_annotated_6 : tensor<8x16xf32>
}

// Resharding from a split without halos to the same split with
// halo_sizes = [2, 2]. The 300-row shard is expected to be inserted at row
// offset 2 of an empty 304-row tensor, followed by mesh.update_halo.
// CHECK-LABEL: func @test_shard_update_halo
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
  // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
  %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
  // CHECK: return %[[UH]] : tensor<304x1200xi64>
  return %sharding_annotated_3 : tensor<1200x1200xi64>
}

// Two-dimensional variant on a 4x4 mesh: both tensor axes are split and
// halo_sizes = [1, 2, 3, 4] is added. The 300x300 shard is expected at offset
// [1, 3] of an empty 303x307 tensor, followed by mesh.update_halo.
mesh.mesh @mesh4x4(shape = 4x4)
// CHECK-LABEL: func @test_shard_update_halo2d
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
  // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
  %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
  // CHECK: return %[[UH]] : tensor<303x307xi64>
  return %sharding_annotated_3 : tensor<1200x1200xi64>
}