// xref: /llvm-project/mlir/test/Dialect/Mesh/resharding-spmdization.mlir (revision baabcb28983edf8f20e39b89e2b1745412073b44)
// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s

// One-dimensional meshes used by the tests in this file:
//   @mesh_1d         - a static mesh of 2 devices;
//   @mesh_1d_dynamic - a mesh whose single dimension has dynamic size (`?`).
mesh.mesh @mesh_1d(shape = 2)
mesh.mesh @mesh_1d_dynamic(shape = ?)

// Two separately defined but equal (replicated) shardings: resharding between
// them must fold away, leaving the function argument as the returned value.
// CHECK-LABEL: func @same_source_and_target_sharding
// CHECK-SAME:    %[[INPUT:.*]]: tensor<2xf32>
// CHECK:         return %[[INPUT]]
func.func @same_source_and_target_sharding(%input: tensor<2xf32>) -> tensor<2xf32> {
  %src_sharding = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %dst_sharding = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %src = mesh.shard %input to %src_sharding : tensor<2xf32>
  %dst = mesh.shard %src to %dst_sharding annotate_for_users : tensor<2xf32>
  return %dst : tensor<2xf32>
}

// Source and target share the very same sharding value: resharding must fold
// away, leaving the function argument as the returned value.
// CHECK-LABEL: func @identical_source_and_target_sharding
// CHECK-SAME:    %[[INPUT:.*]]: tensor<2xf32>
// CHECK:         return %[[INPUT]]
func.func @identical_source_and_target_sharding(%input: tensor<2xf32>) -> tensor<2xf32> {
  %replicated = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %src = mesh.shard %input to %replicated : tensor<2xf32>
  %dst = mesh.shard %src to %replicated annotate_for_users : tensor<2xf32>
  return %dst : tensor<2xf32>
}

// Replicated -> tensor axis 1 split over mesh axis 0 of the 2-device mesh:
// expect an all_slice producing the 3x7 local shard, then a cast back to the
// unsharded tensor type.
// CHECK-LABEL: func @split_replicated_tensor_axis
// CHECK-SAME:    %[[INPUT:.*]]: tensor<3x14xf32>
// CHECK:         %[[SLICED:.*]] = mesh.all_slice %[[INPUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
// CHECK-SAME:    tensor<3x14xf32> -> tensor<3x7xf32>
// CHECK:         %[[UNSHARDED:.*]] = builtin.unrealized_conversion_cast %[[SLICED]] : tensor<3x7xf32> to tensor<3x14xf32>
// CHECK:         return %[[UNSHARDED]] : tensor<3x14xf32>
func.func @split_replicated_tensor_axis(%input: tensor<3x14xf32>) -> tensor<3x14xf32> {
  %replicated = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %split_dim1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %src = mesh.shard %input to %replicated : tensor<3x14xf32>
  %dst = mesh.shard %src to %split_dim1 annotate_for_users : tensor<3x14xf32>
  return %dst : tensor<3x14xf32>
}

// Replicated -> tensor axis 0 split over the dynamically sized mesh. The
// sliced type is identical to the input type, so the all_slice result is
// returned directly with no cast.
// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
// CHECK-SAME:    %[[INPUT:.*]]: tensor<?x3x?xf32>
// CHECK:         %[[SLICED:.*]] = mesh.all_slice %[[INPUT]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
// CHECK-SAME:    tensor<?x3x?xf32> -> tensor<?x3x?xf32>
// CHECK:         return %[[SLICED]] : tensor<?x3x?xf32>
func.func @split_replicated_tensor_axis_dynamic(%input: tensor<?x3x?xf32>) -> tensor<?x3x?xf32> {
  %replicated = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding
  %split_dim0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %src = mesh.shard %input to %replicated : tensor<?x3x?xf32>
  %dst = mesh.shard %src to %split_dim0 annotate_for_users : tensor<?x3x?xf32>
  return %dst : tensor<?x3x?xf32>
}

// Move the split from tensor axis 0 to tensor axis 1 (both over mesh axis 0):
// a single all_to_all turns the 5x14 local shard into a 10x7 one.
// CHECK-LABEL: func @move_split_axis
// CHECK-SAME:    %[[INPUT:.*]]: tensor<10x14xf32>
// CHECK:         %[[SRC_SHARD:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] : tensor<10x14xf32> to tensor<5x14xf32>
// CHECK:         %[[DST_SHARD:.*]] = mesh.all_to_all %[[SRC_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
// CHECK:         %[[UNSHARDED:.*]] = builtin.unrealized_conversion_cast %[[DST_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
// CHECK:         return %[[UNSHARDED]] : tensor<10x14xf32>
func.func @move_split_axis(%input: tensor<10x14xf32>) -> tensor<10x14xf32> {
  %split_dim0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %split_dim1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %src = mesh.shard %input to %split_dim0 : tensor<10x14xf32>
  %dst = mesh.shard %src to %split_dim1 annotate_for_users : tensor<10x14xf32>
  return %dst : tensor<10x14xf32>
}

// Same axis move as the static case, but on the dynamically sized mesh. The
// all_to_all yields a fully dynamic shard, and a tensor.cast then restores
// the static extent of the now-unsplit tensor axis 0.
// CHECK-LABEL: func @move_split_axis_dynamic_mesh
// CHECK-SAME:    %[[INPUT:.*]]: tensor<10x14xf32>
// CHECK:         %[[SRC_SHARD:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] : tensor<10x14xf32> to tensor<?x14xf32>
// CHECK:         %[[EXCHANGED:.*]] = mesh.all_to_all %[[SRC_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
// CHECK:         %[[DST_SHARD:.*]] = tensor.cast %[[EXCHANGED]] : tensor<?x?xf32> to tensor<10x?xf32>
// CHECK:         %[[UNSHARDED:.*]] = builtin.unrealized_conversion_cast %[[DST_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
// CHECK:         return %[[UNSHARDED]] : tensor<10x14xf32>
func.func @move_split_axis_dynamic_mesh(%input: tensor<10x14xf32>) -> tensor<10x14xf32> {
  %split_dim0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %split_dim1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding
  %src = mesh.shard %input to %split_dim0 : tensor<10x14xf32>
  %dst = mesh.shard %src to %split_dim1 annotate_for_users : tensor<10x14xf32>
  return %dst : tensor<10x14xf32>
}

// Move the split from a dynamic tensor axis 0 to the static tensor axis 1 on
// the static mesh. The all_to_all consumes the argument directly.
// CHECK-LABEL: func @move_split_dynamic_axis
// CHECK-SAME:    %[[INPUT:.*]]: tensor<?x14xf32>
// CHECK:         %[[DST_SHARD:.*]] = mesh.all_to_all %[[INPUT]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
// CHECK:         %[[UNSHARDED:.*]] = builtin.unrealized_conversion_cast %[[DST_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
// CHECK:         return %[[UNSHARDED]] : tensor<?x14xf32>
func.func @move_split_dynamic_axis(%input: tensor<?x14xf32>) -> tensor<?x14xf32> {
  %split_dim0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %split_dim1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %src = mesh.shard %input to %split_dim0 : tensor<?x14xf32>
  %dst = mesh.shard %src to %split_dim1 annotate_for_users : tensor<?x14xf32>
  return %dst : tensor<?x14xf32>
}

// Tensor axis 0 split -> replicated: an all_gather along axis 0 rebuilds the
// full 10x14 tensor from the 5x14 local shard.
// CHECK-LABEL: func @unshard_static_axis
// CHECK-SAME:    %[[INPUT:.*]]: tensor<10x14xf32>
// CHECK:         %[[SRC_SHARD:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] : tensor<10x14xf32> to tensor<5x14xf32>
// CHECK:         %[[GATHERED:.*]] = mesh.all_gather %[[SRC_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
// CHECK:         return %[[GATHERED]] : tensor<10x14xf32>
func.func @unshard_static_axis(%input: tensor<10x14xf32>) -> tensor<10x14xf32> {
  %split_dim0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %replicated = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %src = mesh.shard %input to %split_dim0 : tensor<10x14xf32>
  %dst = mesh.shard %src to %replicated annotate_for_users : tensor<10x14xf32>
  return %dst : tensor<10x14xf32>
}

// Tensor axis 1 split -> replicated: an all_gather along axis 1 rebuilds the
// full 10x14 tensor from the 10x7 local shard.
// CHECK-LABEL: func @unshard_static_last_axis
// CHECK-SAME:    %[[INPUT:.*]]: tensor<10x14xf32>
// CHECK:         %[[SRC_SHARD:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] : tensor<10x14xf32> to tensor<10x7xf32>
// CHECK:         %[[GATHERED:.*]] = mesh.all_gather %[[SRC_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
// CHECK:         return %[[GATHERED]] : tensor<10x14xf32>
func.func @unshard_static_last_axis(%input: tensor<10x14xf32>) -> tensor<10x14xf32> {
  %split_dim1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %replicated = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
  %src = mesh.shard %input to %split_dim1 : tensor<10x14xf32>
  %dst = mesh.shard %src to %replicated annotate_for_users : tensor<10x14xf32>
  return %dst : tensor<10x14xf32>
}

// A dynamic tensor axis 0 split -> replicated: the all_gather keeps the
// dynamic type, so its result is returned directly.
// CHECK-LABEL: func @unshard_dynamic_axis
// CHECK-SAME:    %[[INPUT:.*]]: tensor<?x14xf32>
// CHECK:         %[[GATHERED:.*]] = mesh.all_gather %[[INPUT]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
// CHECK:         return %[[GATHERED]] : tensor<?x14xf32>
func.func @unshard_dynamic_axis(%input: tensor<?x14xf32>) -> tensor<?x14xf32> {
  %split_dim0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %replicated = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %src = mesh.shard %input to %split_dim0 : tensor<?x14xf32>
  %dst = mesh.shard %src to %replicated annotate_for_users : tensor<?x14xf32>
  return %dst : tensor<?x14xf32>
}

// Static tensor axis 0 split over the dynamically sized mesh -> replicated.
// The shard and gathered extents are dynamic, so a tensor.cast restores the
// static 10x14 result type.
// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
// CHECK-SAME:    %[[INPUT:.*]]: tensor<10x14xf32>
// CHECK:         %[[SRC_SHARD:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] : tensor<10x14xf32> to tensor<?x14xf32>
// CHECK:         %[[GATHERED:.*]] = mesh.all_gather %[[SRC_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
// CHECK:         %[[STATIC:.*]] = tensor.cast %[[GATHERED]] : tensor<?x14xf32> to tensor<10x14xf32>
// CHECK:         return %[[STATIC]] : tensor<10x14xf32>
func.func @unshard_static_axis_on_dynamic_mesh_axis(%input: tensor<10x14xf32>) -> tensor<10x14xf32> {
  %split_dim0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
  %replicated = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding
  %src = mesh.shard %input to %split_dim0 : tensor<10x14xf32>
  %dst = mesh.shard %src to %replicated annotate_for_users : tensor<10x14xf32>
  return %dst : tensor<10x14xf32>
}

// A value that is a partial sum over mesh axis 0 becomes fully replicated via
// a single all_reduce on that mesh axis, and the reduced value must be what
// the function returns.
// CHECK-LABEL: func @partial_axis_to_full_replication
func.func @partial_axis_to_full_replication(
  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
  %arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
  // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
  // Pin the return explicitly, as every other test in this file does, so the
  // check cannot be satisfied by some other use of the all_reduce result.
  // CHECK: return %[[ALL_REDUCE]] : tensor<10x14xf32>
  return %1 : tensor<10x14xf32>
}
