// xref: /llvm-project/mlir/test/Dialect/Mesh/spmdization.mlir (revision 79eb406a67fe08458548289da72cda18248a9313)
// RUN: mlir-opt \
// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN:   %s | FileCheck %s

// 1-D mesh with a single axis (mesh axis 0) of size 2.
mesh.mesh @mesh_1d(shape = 2)

// Both the producer-side and the user-side annotation are fully replicated
// (empty split axes). The argument keeps its tensor<2xi8> type and is
// expected to be returned directly.
// CHECK-LABEL: func @full_replication
func.func @full_replication(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
  %arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<2xi8> {
) -> tensor<2xi8> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
  // CHECK: return %[[ARG]] : tensor<2xi8>
  return %1 : tensor<2xi8>
}

// Three consecutive annotations on the argument:
//   1. split along tensor dim 0 (producer side),
//   2. the same split for users,
//   3. fully replicated (producer side).
// The spmdized argument is the local tensor<1xf32> shard, and going from
// the split to the replicated layout is expected to produce an all_gather
// along axis 0 over mesh axis 0.
// CHECK-LABEL: func @sharding_triplet
func.func @sharding_triplet(
  // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
  %arg0: tensor<2xf32>
// CHECK-SAME: ) -> tensor<2xf32> {
) -> tensor<2xf32> {
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
  %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated  : tensor<2xf32>
  %ssharding_annotated_0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0  annotate_for_users : tensor<2xf32>
  %ssharding_annotated_1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1  : tensor<2xf32>
  // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
  return %sharding_annotated_1 : tensor<2xf32>
}

// The value is produced split along tensor dim 0, but its user expects it
// split along tensor dim 1 on the same mesh axis. The resharding is expected
// to be a single all_to_all (split_axis = 1, concat_axis = 0) that turns the
// 1x2 local shard into a 2x1 local shard.
// CHECK-LABEL: func @move_split_axis
func.func @move_split_axis(
  // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
  %arg0: tensor<2x2xi8>
// CHECK-SAME: -> tensor<2x1xi8> {
) -> tensor<2x2xi8> {
  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
  // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0  : tensor<2x2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2x2xi8>
  // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
  return %1 : tensor<2x2xi8>
}

// A scalar (non-tensor) value with no sharding annotations: the arith.addi
// is expected to appear in the output on the unchanged i8 argument.
// CHECK-LABEL: func @non_tensor_value
func.func @non_tensor_value(
  // CHECK-SAME: %[[ARG:.*]]: i8
  %arg0: i8
// CHECK-SAME: -> i8 {
) -> i8 {
  // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
  %0 = arith.addi %arg0, %arg0 : i8
  // CHECK: return %[[RES]] : i8
  return %0 : i8
}

// The operand and the result of tosa.abs carry matching dim-0 splits on both
// sides of every annotation pair. tosa.abs is therefore expected to run
// directly on the local tensor<1xi8> shard of the argument.
// CHECK-LABEL: func @unary_elementwise
func.func @unary_elementwise(
  // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
  %arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
  // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
  %s4 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RES]] : tensor<1xi8>
  return %4 : tensor<2xi8>
}

// The replicated argument is resharded to a dim-0 split (expected all_slice).
// tosa.abs then runs on the local tensor<1xi8> shard. Its result is resharded
// back to full replication (expected all_gather).
// CHECK-LABEL: func @unary_elementwise_with_resharding
func.func @unary_elementwise_with_resharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
  %arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<2xi8> {
) -> tensor<2xi8> {
  // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
  // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
  // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
  // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
  %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RES]] : tensor<2xi8>
  return %4 : tensor<2xi8>
}

// Both operands and the result of tosa.add are split along tensor dim 0 on
// every annotation. tosa.add is therefore expected to be applied directly to
// the two local tensor<1xi8> argument shards.
// CHECK-LABEL: func @binary_elementwise
func.func @binary_elementwise(
  // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
  %arg0: tensor<2xi8>,
  // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
  %arg1: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
  %sarg0_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded  : tensor<2xi8>
  %sop_arg0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0  annotate_for_users : tensor<2xi8>
  %sarg1_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded  : tensor<2xi8>
  %sop_arg1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1  annotate_for_users : tensor<2xi8>
  // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
  %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
  %sop_res_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %op_res_sharded = mesh.shard %op_res to %sop_res_sharded  : tensor<2xi8>
  %sres = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %res = mesh.shard %op_res_sharded to %sres  annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RES]] : tensor<1xi8>
  return %res : tensor<2xi8>
}

// A chain of resharding and elementwise ops:
//   1. reshard replicated -> dim-0 split (expected all_slice)
//   2. tosa.abs on the local tensor<1xi8> shard
//   3. reshard dim-0 split -> replicated (expected all_gather)
//   4. tosa.abs on the full tensor<2xi8>
//   5. reshard replicated -> dim-0 split (expected all_slice)
// The spmdized function returns the final tensor<1xi8> shard.
// CHECK-LABEL: func @multiple_chained_ops
func.func @multiple_chained_ops(
  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
  %arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
  // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
  // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
  // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
  // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
  %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
  // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
  %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
  // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
  %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %6 = mesh.shard %5 to %s6  : tensor<2xi8>
  %s7 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %7 = mesh.shard %6 to %s7  annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RESHARD3]] : tensor<1xi8>
  return %7 : tensor<2xi8>
}

// Only half of each annotation pair is present. The argument has only a
// user-side annotation (annotate_for_users), and the sigmoid result has only
// a producer-side annotation. Both are dim-0 splits, so tosa.sigmoid is
// expected to run on the local tensor<4x16xf32> shard.
// CHECK-LABEL: func @incomplete_sharding
func.func @incomplete_sharding(
  // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
  %arg0: tensor<8x16xf32>
// CHECK-SAME: -> tensor<4x16xf32> {
) -> tensor<8x16xf32> {
  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %s0  annotate_for_users : tensor<8x16xf32>
  // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %s2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %2 = mesh.shard %1 to %s2  : tensor<8x16xf32>
  // CHECK: return %[[RES]] : tensor<4x16xf32>
  return %2 : tensor<8x16xf32>
}

// 1-D mesh with a single axis (mesh axis 0) of size 4.
mesh.mesh @mesh_1d_4(shape = 4)

// A chain of elementwise ops (tanh -> abs -> negate) whose operands and
// results all carry the same dim-0 split over @mesh_1d_4 with
// halo_sizes = [2, 1]. Each local shard holds 8 / 4 = 2 rows plus 2 + 1 halo
// rows, giving tensor<5x16xf32>. The ops are expected to follow one another
// directly, with nothing emitted between them.
// CHECK-LABEL: func @ew_chain_with_halo
func.func @ew_chain_with_halo(
  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
  %arg0: tensor<8x16xf32>)
  // CHECK-SAME: -> tensor<5x16xf32>
   -> tensor<8x16xf32> {
  %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated  annotate_for_users : tensor<8x16xf32>
  // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
  %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0  : tensor<8x16xf32>
  %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1  annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
  %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2  : tensor<8x16xf32>
  %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4  annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
  %2 = tosa.negate %sharding_annotated_4 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5  : tensor<8x16xf32>
  %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
  %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6  annotate_for_users : tensor<8x16xf32>
  // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
  return %sharding_annotated_6 : tensor<8x16xf32>
}

// Reshards a dim-0 split over @mesh_1d_4 without halos into the same split
// with halo_sizes = [2, 2]. Each 300-row local shard (1200 / 4) is expected to
// be inserted at row offset 2 of a fresh 304-row (300 + 2 + 2) tensor. That
// tensor is then passed to mesh.update_halo, and the function's result type
// becomes the halo-extended shard type.
// CHECK-LABEL: func @test_shard_update_halo
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
// CHECK-SAME: -> tensor<304x1200xi64>
func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
  // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
  %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
  // CHECK: return %[[UH]] : tensor<304x1200xi64>
  return %sharding_annotated_3 : tensor<1200x1200xi64>
}

// 2-D mesh of shape 4x4 (mesh axes 0 and 1, each of size 4).
mesh.mesh @mesh4x4(shape = 4x4)
// 2-D variant: tensor dims 0 and 1 are split over mesh axes 0 and 1 of
// @mesh4x4, then resharded with halo_sizes = [1, 2, 3, 4]. The 300x300 local
// shard is expected at offset [1, 3] of a fresh 303x307 tensor. This shows the
// sizes act as (lower, upper) pairs per split dim (1 + 2 rows, 3 + 4 columns).
// The tensor is then passed to mesh.update_halo, and the function's result
// type becomes the halo-extended shard type.
// CHECK-LABEL: func @test_shard_update_halo2d
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
// CHECK-SAME: -> tensor<303x307xi64>
func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
  // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
  %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
  // CHECK: return %[[UH]] : tensor<303x307xi64>
  return %sharding_annotated_3 : tensor<1200x1200xi64>
}