// xref: /llvm-project/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir (revision 2f664f2bdf2a3974f5c35ea0278239011181eb87)
// RUN: mlir-opt \
// RUN:   --verify-each \
// RUN:   --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
// RUN:   %s | FileCheck %s

// One-dimensional mesh of size 2 (a single mesh axis, axis 0); every
// mesh.sharding in this file refers to it.
mesh.mesh @mesh_2(shape = 2)

// Sharding propagation through linalg.matmul, seeded from two existing
// annotations: the LHS %arg0 is sharded with its dimension 0 split over mesh
// axis 0, and the matmul result is annotated for its users with split_axes
// [[]] (no mesh axis assigned to dimension 0). Dimension 0 of the LHS is the
// matmul's M dimension, which linalg.matmul iterates in parallel.
//
// The expected output below encodes that the pass:
//   - keeps the existing annotation on %arg0 and adds an annotate_for_users
//     copy with split_axes [[0]], which is what the matmul consumes;
//   - annotates the RHS %arg1 for users with split_axes [[]];
//   - annotates the DPS init %out_dps for users with split_axes [[0]];
//   - annotates the matmul result with split_axes [[0]]; the pre-existing
//     [[]] annotate_for_users annotation then applies to that value, and the
//     function returns its result.
// CHECK-LABEL: func @matmul_shard_prallel_axis
func.func @matmul_shard_prallel_axis(
  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
  %arg0 : tensor<2x3xf32>,
  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
  %arg1 : tensor<3x2xf32>,
  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
  %out_dps: tensor<2x2xf32>
) -> tensor<2x2xf32> {
  // CHECK: %[[SIN1_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
  // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
  // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
  // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
  // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
  // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
  // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
  // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
  // Seed annotation: split dimension 0 of the LHS over mesh axis 0.
  %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>

  // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
  // CHECK-SAME:  outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>

  // CHECK: %[[SRES_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
  // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
  // CHECK: %[[SRES_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
  // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
  // Seed annotation: the result's users request split_axes [[]] on it.
  %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
  %res_sharded = mesh.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>

  // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
  return %res_sharded : tensor<2x2xf32>
}
