// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s

// Meshes referenced by the tests below: a 1-D mesh of extent 2, a 1-D mesh
// with an unspecified (`?`) extent, a 2x4 mesh, and a 3-D mesh whose extents
// are all unspecified.
mesh.mesh @mesh_2(shape = 2)
mesh.mesh @mesh_1d(shape = ?)
mesh.mesh @mesh_2d(shape = 2x4)
mesh.mesh @mesh_3d(shape = ?x?x?)

// The input carries no sharding annotation; the expected output contains only
// the sigmoid followed by the return.
// CHECK-LABEL: func.func @element_wise_empty_sharding_info
func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-NEXT: tosa.sigmoid
  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: return
  return %0 : tensor<8x16xf32>
}

// Only the sigmoid result is annotated (split_axes = [[0], [1]] on @mesh_2d).
// The expected output additionally annotates the sigmoid operand for its
// users with the same split.
// CHECK-LABEL: func.func @element_wise_on_def
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] : tensor<8x16xf32>
  %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 : tensor<8x16xf32>
  // CHECK-NEXT: return %[[V2]]
  return %1 : tensor<8x16xf32>
}

// Only the sigmoid operand is annotated (annotate_for_users). The expected
// output additionally annotates the sigmoid result with the same split.
// CHECK-LABEL: func.func @element_wise_on_use
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
  %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] : tensor<8x16xf32>
  // CHECK-NEXT: return %[[V2]]
  return %1 : tensor<8x16xf32>
}

// The sigmoid result is annotated with annotate_for_users only. The expected
// output adds an annotate_for_users shard on the argument and a plain shard on
// the sigmoid result ahead of the original annotation.
// CHECK-LABEL: func.func @element_wise_on_graph_output
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
  // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] annotate_for_users : tensor<8x16xf32>
  %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: return %[[V3]]
  return %1 : tensor<8x16xf32>
}

// The argument is annotated without annotate_for_users. The expected output
// adds an annotate_for_users shard after it and a plain shard on the sigmoid
// result.
// CHECK-LABEL: func.func @element_wise_on_graph_input
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32>
  // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
  %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<8x16xf32>
  // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
  // CHECK-NEXT: return %[[V3]]
  return %1 : tensor<8x16xf32>
}

// tanh feeds both abs and negate; only the negate result is annotated in the
// input. The expected output shards the argument, the tanh result and the abs
// result with the same [[0], [1]] split.
// CHECK-LABEL: func.func @arrow_structure
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
  // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
  // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
  %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
  // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32>
  %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]]
  // CHECK-NEXT: %[[S8:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<8x16xf32>
  %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
  %3 = mesh.shard %2 to %s3 : tensor<8x16xf32>
  // CHECK-NEXT: return %[[V6]], %[[V8]]
  return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
}

// The matmul result is annotated with split_axes = [[0], [1]]. The expected
// output annotates the first operand with [[0], [1]] and the second operand
// with [[0]].
// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
  // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding
  // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
  // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
  // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
  %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
  // CHECK-NEXT: return %[[V3]]
  return %1 : tensor<2x16x32xf32>
}

// The matmul result is annotated with split_axes = [[], [1]] and
// partial = sum [0]. The expected output annotates the first operand with
// [[], [1], [0]] and the second operand with [[], [0]].
// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
  // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
  // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
  // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
  %s1 = mesh.sharding @mesh_2d split_axes = [[], [1]] partial = sum [0] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
  // CHECK-NEXT: return %[[V3]]
  return %1 : tensor<2x16x32xf32>
}

// Only the first matmul operand is annotated (annotate_for_users,
// split_axes = [[], [1], [0]]). The expected output annotates the second
// operand with [[], [0]] and the result with [[], [1]] partial = sum [0].
// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
  %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
  // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
  %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
  // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
  // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
  // CHECK-NEXT: return %[[V3]]
  return %1 : tensor<2x16x32xf32>
}

// Both matmul operands are annotated for users; the expected output keeps both
// annotations and annotates the result with [[], [1]] partial = sum [0].
// NOTE(review): "duplicted" in the test name looks like a typo for
// "duplicated"; the name is left unchanged here.
// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
  %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
  // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
  %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32>
  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
  %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
  // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
  // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
  // CHECK-NEXT: return %[[V3]]
  return %2 : tensor<2x16x32xf32>
}

// The first linalg.matmul input is sharded with split_axes = [[0]] on @mesh_2
// while the matmul result is sharded with split_axes = [[]]. In the expected
// output every matmul operand is annotated for users with the [[]] sharding,
// and the original [[0]] shard on the first input is kept in front of it.
// CHECK-LABEL: func.func @resolve_conflicting_annotations
func.func @resolve_conflicting_annotations(
  // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
  %arg0: tensor<2x3xf32>,
  // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
  %arg1: tensor<3x2xf32>,
  // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
  %out_dps: tensor<2x2xf32>
// CHECK-SAME: ) -> tensor<2x2xf32> {
) -> tensor<2x2xf32> {
  // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
  // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
  // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
  // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
  // CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
  // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
  %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
  // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
  // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: %[[SRES:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
  // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SRES]] : tensor<2x2xf32>
  %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
  %res_sharded = mesh.shard %res to %sres_sharded : tensor<2x2xf32>
  // CHECK: return %[[RES]] : tensor<2x2xf32>
  return %res_sharded : tensor<2x2xf32>
}

// matmul -> sigmoid -> matmul chain on @mesh_1d. The input annotates only the
// first argument ([[], [], [0]]) and the final result (partial = sum [0], then
// [[], [], [0]] for users); the intermediate shardings are expected to be
// filled in by the pass.
// https://arxiv.org/abs/2211.05102 Figure 2(a)
// CHECK-LABEL: func.func @mlp_1d_weight_stationary
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
  // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
  // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
  // CHECK: %[[V0:.*]] = tosa.matmul
  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
  // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S2]] : tensor<2x4x32xf32>
  // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32>
  // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
  %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
  // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S2]] : tensor<2x4x32xf32>
  // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32>
  // CHECK-DAG: %[[S6:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[ARG2]] to %[[S6]] annotate_for_users : tensor<2x32x8xf32>
  // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]]
  %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
  %s4 = mesh.sharding @mesh_1d split_axes = [[], [], []] partial = sum [0] : !mesh.sharding
  %4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32>
  // CHECK: %[[S8:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], []] partial = sum [0] : !mesh.sharding
  // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<2x4x8xf32>
  %s5 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding
  %5 = mesh.shard %4 to %s5 annotate_for_users : tensor<2x4x8xf32>
  // CHECK: %[[V9:.*]] = mesh.shard %[[V8]] to %[[S1]] annotate_for_users : tensor<2x4x8xf32>
  // CHECK-NEXT: return %[[V9]]
  return %5 : tensor<2x4x8xf32>
}

// matmul -> sigmoid -> matmul chain on @mesh_3d. The input annotates the first
// argument, the result of each matmul (with partial sums) and the final value
// for users; the operand shardings of both matmuls and the sigmoid are expected
// to be filled in by the pass.
// https://arxiv.org/abs/2211.05102 Figure 2(b)
// CHECK-LABEL: func.func @mlp_2d_weight_stationary
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
  // CHECK-DAG: %[[S0:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] : tensor<2x4x8xf32>
  %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
  // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users : tensor<2x4x8xf32>
  // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
  // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[ARG1]] to %[[S2]] annotate_for_users : tensor<2x8x32xf32>
  // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]]
  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
  // CHECK-DAG: %[[S4:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] partial = sum [0] : !mesh.sharding
  // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S4]] : tensor<2x4x32xf32>
  %s2 = mesh.sharding @mesh_3d split_axes = [[], [], [1, 2]] partial = sum [0] : !mesh.sharding
  %2 = mesh.shard %1 to %s2 : tensor<2x4x32xf32>
  // CHECK-DAG: %[[S5:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] : !mesh.sharding
  // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S5]] annotate_for_users : tensor<2x4x32xf32>
  // CHECK-DAG: %[[V6:.*]] = tosa.sigmoid %[[V5]]
  %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
  // CHECK-NEXT: %[[V7:.*]] = mesh.shard %[[V6]] to %[[S5]] : tensor<2x4x32xf32>
  // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S5]] annotate_for_users : tensor<2x4x32xf32>
  // CHECK-DAG: %[[S9:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
  // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[ARG2]] to %[[S9]] annotate_for_users : tensor<2x32x8xf32>
  // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]]
  %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
  // CHECK-DAG: %[[S11:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] partial = sum [1, 2] : !mesh.sharding
  // CHECK-NEXT: %[[V11:.*]] = mesh.shard %[[V10]] to %[[S11]] : tensor<2x4x8xf32>
  %s5 = mesh.sharding @mesh_3d split_axes = [[], [], [0]] partial = sum [1, 2] : !mesh.sharding
  %5 = mesh.shard %4 to %s5 : tensor<2x4x8xf32>
  // CHECK-NEXT: %[[V12:.*]] = mesh.shard %[[V11]] to %[[S0]] annotate_for_users : tensor<2x4x8xf32>
  %s6 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
  %6 = mesh.shard %5 to %s6 annotate_for_users : tensor<2x4x8xf32>
  // CHECK-DAG: return %[[V12]]
  return %6 : tensor<2x4x8xf32>
}

// Two chained sigmoids where only the final result is annotated, with
// split_axes = [[]] on @mesh_2d. The expected output annotates every value in
// the chain with that same sharding.
// CHECK-LABEL: func.func @elementwise_duplicated_chain
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
  // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
  // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] : tensor<8x16xf32>
  %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
  %2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
  // CHECK-NEXT: return %[[V5]]
  return %2 : tensor<8x16xf32>
}