// RUN: mlir-opt \
// RUN:   --verify-each \
// RUN:   --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
// RUN:   %s | FileCheck %s

mesh.mesh @mesh_2(shape = 2)

// Sharding propagation through a `linalg.matmul` (2x3 * 3x2 -> 2x2) on a 1-D
// mesh of size 2.
//
// Input annotations:
//   * the LHS is split along tensor dim 0 (split_axes = [[0]]); dim 0 of the
//     LHS is the M dimension, a parallel iterator of the matmul;
//   * the matmul result carries an `annotate_for_users` annotation with
//     split_axes = [[]], i.e. not split on any mesh axis.
//
// Expected output of the pass:
//   * the LHS gets an extra `annotate_for_users` annotation with the same
//     dim-0 split;
//   * the RHS is annotated for users with split_axes = [[]];
//   * the DPS init operand is annotated for users with a dim-0 split;
//   * the matmul result gets a dim-0 split annotation, followed by the
//     original user-side split_axes = [[]] annotation.
//
// NOTE: "prallel" in the function name is a typo for "parallel"; the name is
// left as-is.

// CHECK-LABEL: func @matmul_shard_prallel_axis
// CHECK-SAME:    %[[LHS:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
// CHECK-SAME:    %[[RHS:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
// CHECK-SAME:    %[[INIT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
func.func @matmul_shard_prallel_axis(%lhs: tensor<2x3xf32>,
                                     %rhs: tensor<3x2xf32>,
                                     %init: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // Operand annotations produced by the pass.
  // CHECK:      %[[LHS_SHARDING_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
  // CHECK-NEXT: %[[LHS_SHARDED_0:.*]] = mesh.shard %[[LHS]] to %[[LHS_SHARDING_0]] : tensor<2x3xf32>
  // CHECK:      %[[LHS_SHARDING_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
  // CHECK-NEXT: %[[LHS_SHARDED_1:.*]] = mesh.shard %[[LHS_SHARDED_0]] to %[[LHS_SHARDING_1]] annotate_for_users : tensor<2x3xf32>
  // CHECK:      %[[RHS_SHARDING:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
  // CHECK-NEXT: %[[RHS_SHARDED:.*]] = mesh.shard %[[RHS]] to %[[RHS_SHARDING]] annotate_for_users : tensor<3x2xf32>
  // CHECK:      %[[INIT_SHARDING:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
  // CHECK-NEXT: %[[INIT_SHARDED:.*]] = mesh.shard %[[INIT]] to %[[INIT_SHARDING]] annotate_for_users : tensor<2x2xf32>
  %lhs_sharding = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
  %lhs_sharded = mesh.shard %lhs to %lhs_sharding : tensor<2x3xf32>

  // The matmul must consume the annotated operands.
  // CHECK:      %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS_SHARDED_1]], %[[RHS_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
  // CHECK-SAME:   outs(%[[INIT_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
  %matmul = linalg.matmul ins(%lhs_sharded, %rhs : tensor<2x3xf32>, tensor<3x2xf32>)
                          outs(%init : tensor<2x2xf32>) -> tensor<2x2xf32>

  // Result: the propagated dim-0 split, then the pre-existing user-side
  // annotation with split_axes = [[]].
  // CHECK:      %[[RESULT_SHARDING_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
  // CHECK-NEXT: %[[RESULT_SHARDED_0:.*]] = mesh.shard %[[MATMUL]] to %[[RESULT_SHARDING_0]] : tensor<2x2xf32>
  // CHECK:      %[[RESULT_SHARDING_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
  // CHECK-NEXT: %[[RESULT_SHARDED_1:.*]] = mesh.shard %[[RESULT_SHARDED_0]] to %[[RESULT_SHARDING_1]] annotate_for_users : tensor<2x2xf32>
  %result_sharding = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
  %result_sharded = mesh.shard %matmul to %result_sharding annotate_for_users : tensor<2x2xf32>

  // CHECK: return %[[RESULT_SHARDED_1]] : tensor<2x2xf32>
  return %result_sharded : tensor<2x2xf32>
}