// xref: /llvm-project/mlir/test/Dialect/Mesh/canonicalization.mlir (revision ffc7feadece139c88f0e6930f16bfa9293747adc)
// RUN: mlir-opt --canonicalize %s | FileCheck %s

// 2x4 device mesh referenced as @mesh0 by the collective tests below.
mesh.mesh @mesh0(shape = 2x4)

// Canonicalization is expected to erase an all_reduce over an empty mesh_axes
// list and return the operand directly.
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_reduce
  %0 = mesh.all_reduce %arg0 on @mesh0
    mesh_axes = []
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// When the result element type differs from the operand's (f32 -> f64) the
// all_reduce must survive canonicalization; only the empty mesh_axes is
// expected to disappear from the printed op.
// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
func.func @all_reduce_empty_mesh_axes_different_return_type(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
// CHECK: mesh.all_reduce
  %0 = mesh.all_reduce %arg0 on @mesh0
// CHECK-NOT: mesh_axes
    mesh_axes = []
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// An explicit `reduction = sum` is expected to be omitted when the op is
// printed (per the test name, sum is the default reduction kind).
// CHECK-LABEL: func @all_reduce_default_reduction
func.func @all_reduce_default_reduction(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  %0 = mesh.all_reduce %arg0 on @mesh0
    mesh_axes = [0]
// CHECK-NOT: reduction
    reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// Canonicalization is expected to erase an all_to_all over an empty mesh_axes
// list and return the operand directly.
// CHECK-LABEL: func @all_to_all_empty_mesh_axes
func.func @all_to_all_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
    %arg0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-NOT: mesh.all_to_all
  %0 = mesh.all_to_all %arg0 on @mesh0
    mesh_axes = []
    split_axis = 0
    concat_axis = 0
    : tensor<8xf32> -> tensor<8xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<8xf32>
}

// Canonicalization is expected to erase an all_gather over an empty mesh_axes
// list and return the operand directly.
// CHECK-LABEL: func @all_gather_empty_mesh_axes
func.func @all_gather_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_gather
  %0 = mesh.all_gather %arg0 on @mesh0
    mesh_axes = []
    gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// Canonicalization is expected to erase an all_slice over an empty mesh_axes
// list and return the operand directly. The negative check names the op under
// test (mesh.all_slice); a check for any other op name would pass trivially.
// CHECK-LABEL: func @all_slice_empty_mesh_axes
func.func @all_slice_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_slice
  %0 = mesh.all_slice %arg0 on @mesh0
    mesh_axes = []
    slice_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// Canonicalization is expected to erase a broadcast with empty mesh_axes and
// an empty root, returning the operand directly.
// CHECK-LABEL: func @broadcast_empty_mesh_axes
func.func @broadcast_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.broadcast
  %0 = mesh.broadcast %arg0 on @mesh0
    mesh_axes = []
    root = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// Canonicalization is expected to erase a gather with empty mesh_axes and an
// empty root, returning the operand directly.
// CHECK-LABEL: func @gather_empty_mesh_axes
func.func @gather_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.gather
  %0 = mesh.gather %arg0 on @mesh0
    mesh_axes = []
    gather_axis = 0
    root = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// Canonicalization is expected to erase a recv over an empty mesh_axes list
// and return the operand directly.
// CHECK-LABEL: func @receive_empty_mesh_axes
func.func @receive_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.recv
  %0 = mesh.recv %arg0 on @mesh0
    mesh_axes = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// Canonicalization is expected to erase a reduce with empty mesh_axes and an
// empty root, returning the operand directly.
// CHECK-LABEL: func @reduce_empty_mesh_axes
func.func @reduce_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.reduce
  %0 = mesh.reduce %arg0 on @mesh0
    mesh_axes = []
    root = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// Canonicalization is expected to erase a reduce_scatter over an empty
// mesh_axes list (same operand and result type) and return the operand directly.
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
func.func @reduce_scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.reduce_scatter
  %0 = mesh.reduce_scatter %arg0 on @mesh0
    mesh_axes = []
    scatter_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// When the result element type differs from the operand's (f32 -> f64) the
// reduce_scatter must survive canonicalization; only the empty mesh_axes is
// expected to disappear from the printed op.
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
func.func @reduce_scatter_empty_mesh_axes_different_return_type(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
// CHECK: mesh.reduce_scatter
  %0 = mesh.reduce_scatter %arg0 on @mesh0
// CHECK-NOT: mesh_axes
    mesh_axes = []
    scatter_axis = 0
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// An explicit `reduction = sum` is expected to be omitted when the op is
// printed (per the test name, sum is the default reduction kind).
// CHECK-LABEL: func @reduce_scatter_default_reduction
func.func @reduce_scatter_default_reduction(
    %arg0 : tensor<4xf32>) -> tensor<2xf64> {
  %0 = mesh.reduce_scatter %arg0 on @mesh0
    mesh_axes = [0]
// CHECK-NOT: reduction
    reduction = sum
    scatter_axis = 0
    : tensor<4xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

// Canonicalization is expected to erase a scatter with empty mesh_axes and an
// empty root, returning the operand directly.
// CHECK-LABEL: func @scatter_empty_mesh_axes
func.func @scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.scatter
  %0 = mesh.scatter %arg0 on @mesh0
    mesh_axes = []
    scatter_axis = 0
    root = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// Canonicalization is expected to erase a send with empty mesh_axes and an
// empty destination, returning the operand directly.
// CHECK-LABEL: func @send_empty_mesh_axes
func.func @send_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.send
  %0 = mesh.send %arg0 on @mesh0
    mesh_axes = []
    destination = []
    : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// 4x4 device mesh referenced as @mesh4x4 by the sharding tests below.
mesh.mesh @mesh4x4(shape = 4x4)
// CHECK-LABEL: func @test_halo_sizes
func.func @test_halo_sizes() -> !mesh.sharding {
  // Canonicalization is expected to fold the constant SSA halo sizes into the
  // static list. The leading `[[` of split_axes is written as a regex block
  // because FileCheck would otherwise parse it as a variable substitution.
  %c2_i64 = arith.constant 2 : i64
  // CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
  return %sharding : !mesh.sharding
}

// CHECK-LABEL: func @test_shard_offs
func.func @test_shard_offs() -> !mesh.sharding {
  // Canonicalization is expected to fold the constant SSA offsets into the
  // static sharded_dims_offsets list. The leading `[[` of split_axes is written
  // as a regex block because FileCheck would otherwise parse it as a variable
  // substitution.
  %c2_i64 = arith.constant 2 : i64
  // CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
  return %sharding : !mesh.sharding
}