// RUN: mlir-opt \
// RUN:  --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN:  --split-input-file \
// RUN:  %s | FileCheck %s

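// The mesh-spmdization pass rewrites each mesh.shard-annotated op into its
// per-process SPMD form; test-constant-fold then folds the resulting shape
// arithmetic so FileCheck can match the static per-process tensor types.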
// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)>
#map_identity_1d = affine_map<(d0) -> (d0)>

mesh.mesh @mesh_1d(shape = 2)

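// Elementwise multiplication of 2-element tensors on a 2-process mesh: every
// operand is split along its single dimension, so each process computes an
// independent tensor<1xi8> slice with no communication.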
// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor
func.func @elementwise_static_1d_mesh_static_1d_tensor(
  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>,
  %in1: tensor<2xi8>,
  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>,
  %in2: tensor<2xi8>,
  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1xi8>
  %dps_out: tensor<2xi8>
  // CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
  %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %in1_sharded1 = mesh.shard %in1 to %sharding : tensor<2xi8>
  %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
  %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8>
  %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
  %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8>
  %dps_out_sharded2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
  // CHECK: %[[RES:.*]] = linalg.generic {
  // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
  // CHECK-SAME: iterator_types = ["parallel"]}
  // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1xi8>, tensor<1xi8>)
  // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1xi8>) {
  %res = linalg.generic {
      indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d],
      iterator_types = ["parallel"]
    } ins(%in1_sharded2, %in2_sharded2 : tensor<2xi8>, tensor<2xi8>)
      outs(%dps_out_sharded2 : tensor<2xi8>) {
    ^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8):
      %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
      linalg.yield %res_scalar : i8
    } -> tensor<2xi8>
  %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8>
  %res_sharded2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
  // CHECK: return %[[RES]] : tensor<1xi8>
  return %res_sharded2 : tensor<2xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 4)

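// Matmul where only the parallel (row) iterator is sharded: in1 and the DPS
// out operand are split along dim 0 across the 4 processes while in2 stays
// replicated, so each process computes its own tensor<1x8xi8> row block
// without any communication.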
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding
func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>,
  %in1: tensor<4x3xi8>,
  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>,
  %in2: tensor<3x8xi8>,
  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1x8xi8>
  %dps_out: tensor<4x8xi8>
  // CHECK-SAME: -> tensor<1x8xi8> {
) -> tensor<4x8xi8> {
  %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8>
  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
  %sharding2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<3x8xi8>
  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
  %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8>
  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
  // CHECK: %[[RES:.*]] = linalg.matmul
  // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
  // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
  // CHECK-SAME: -> tensor<1x8xi8>
  %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
      outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
  %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8>
  %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
  // CHECK: return %[[RES]] : tensor<1x8xi8>
  return %res_shared2 : tensor<4x8xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 3)

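// Matmul where the reduction (k) iterator is sharded: in1 is split along
// dim 1 and in2 along dim 0 (k = 6, so 2 elements per process on the
// 3-process mesh). Each process computes a partial tensor<4x8xi8> product
// that is combined with mesh.all_reduce.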
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding
func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
  %in1: tensor<4x6xi8>,
  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
  %in2: tensor<6x8xi8>,
  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
  %dps_out: tensor<4x8xi8>
  // CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
  %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
  %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
  %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
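  // Only process 0 seeds the accumulator with the original out operand; the
  // other processes start from a zero-filled tensor (the neutral element of
  // the sum reduction) so the final all_reduce counts each contribution once.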
  // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG:  %[[C0_I8:.*]] = arith.constant 0 : i8
  // CHECK-DAG:  %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
  // CHECK-DAG:  %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
  // CHECK:      %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
  // CHECK:      %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
  // CHECK:        scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
  // CHECK:      } else {
  // CHECK-DAG:    %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
  // CHECK:        %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
  // CHECK-SAME:       outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
  // CHECK:        scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
  // CHECK:      }
  // CHECK:      %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
  // CHECK-SAME:     outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
  // CHECK:      %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
  %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
      outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
  %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8>
  %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
  // CHECK:      return %[[ALL_REDUCED]] : tensor<4x8xi8>
  return %res_shared2 : tensor<4x8xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 3)

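// Same reduction-sharded matmul as above, but the result is annotated with
// `partial = sum[0]`: the unreduced per-process partial sums are the desired
// result, so no mesh.all_reduce is emitted.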
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result
func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result(
  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
  %in1: tensor<4x6xi8>,
  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
  %in2: tensor<6x8xi8>,
  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
  %dps_out: tensor<4x8xi8>
  // CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
  %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
  %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
  %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
  %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
  // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG:  %[[C0_I8:.*]] = arith.constant 0 : i8
  // CHECK-DAG:  %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
  // CHECK-DAG:  %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
  // CHECK:      %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
  // CHECK:      %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
  // CHECK:        scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
  // CHECK:      } else {
  // CHECK-DAG:    %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
  // CHECK:        %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
  // CHECK-SAME:       outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
  // CHECK:        scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
  // CHECK:      }
  // CHECK:      %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
  // CHECK-SAME:     outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
  %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
      outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
  %sharding4 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
  %res_shared1 = mesh.shard %res to %sharding4 : tensor<4x8xi8>
  %res_shared2 = mesh.shard %res_shared1 to %sharding4 annotate_for_users : tensor<4x8xi8>
  // CHECK:      return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
  return %res_shared2 : tensor<4x8xi8>
}

// -----

mesh.mesh @mesh_1d(shape = 4)

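// Matmul where all operands start replicated but in2 and the DPS out operand
// are requested sharded along their last dimension: spmdization inserts
// mesh.all_slice to cut out each process's column block (8 -> 2 columns on
// the 4-process mesh) and a final mesh.all_gather to rebuild the replicated
// tensor<4x8xi8> result.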
// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
  %in1: tensor<4x6xi8>,
  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
  %in2: tensor<6x8xi8>,
  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
  %dps_out: tensor<4x8xi8>
  // CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
  %sharding1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
  %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8>
  %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
  // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
  %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8>
  %sharding2 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
  %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
  // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
  %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8>
  %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
  // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
  // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
  // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
  // CHECK-SAME: -> tensor<4x2xi8>
  %res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
      outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
  %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8>
  %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
  // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
  return %res_replicated : tensor<4x8xi8>
}