// Source: llvm-project/mlir/test/Dialect/Mesh/simplifications.mlir
// (revision baabcb28983edf8f20e39b89e2b1745412073b44)
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s

// @mesh0 has two axes, so tests can reduce over mesh axis 0 or 1.
// @mesh1 is a second, distinct mesh used to check that collectives on
// different meshes are not combined.
mesh.mesh @mesh0(shape = 4x2)
mesh.mesh @mesh1(shape = 4)

// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
// `all_reduce(x + y)` when both all-reduces use the default (sum) reduction
// over the same mesh and the same mesh axes.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
func.func @all_reduce_arith_addf_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // The add is applied directly to the function arguments...
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // ...followed by one all-reduce that keeps the original mesh and mesh axes.
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0]
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xf32>
}

// Checks that the rewrite still applies when the result of the add has several
// uses: every use is redirected to the new all-reduce of the sum.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result
func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0]
  // Both returned values must refer to the single new all-reduce.
  // CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]]
  return %2, %2 : tensor<5xf32>, tensor<5xf32>
}

// Checks that the rewrite is NOT applied when the result of one of the
// all-reduces has another use besides the add (here it is also returned).
// Both all-reduces and the add must remain unchanged.
// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]]
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]]
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ALL_REDUCE_0_RES]], %[[ADD_RES]]
  return %0, %2 : tensor<5xf32>, tensor<5xf32>
}

// Checks that the rewrite is not applied when the two all-reduces are over
// different meshes (@mesh0 vs @mesh1).
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
  %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

// Checks that the rewrite is not applied when the two all-reduces are on the
// same mesh but over different mesh axes ([0] vs [1]).
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

// Checks that the rewrite is not applied when one of the all-reduces uses a
// reduction kind (`max`) that does not match the `arith.addf` combiner.
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf32>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf32>
}

// Checks that the rewrite is not applied when the all-reduces change the
// element type (f32 operands, f64 results): an add of the f32 operands would
// not produce the f64 type that the original add returns.
// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf64> {
  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf64>
  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
    : tensor<5xf32> -> tensor<5xf64>
  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
  %2 = arith.addf %0, %1 : tensor<5xf64>
  // CHECK: return %[[ADD_RES]]
  return %2 : tensor<5xf64>
}

// Checks that `minimumf(all_reduce(x), all_reduce(y))` with `min` reductions
// gets transformed to `all_reduce(minimumf(x, y))`, keeping the mesh, the mesh
// axes and the `min` reduction kind.
// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
func.func @all_reduce_arith_minimumf_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg0: tensor<5xf32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
    %arg1: tensor<5xf32>) -> tensor<5xf32> {
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xf32> -> tensor<5xf32>
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xf32> -> tensor<5xf32>
  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
  %2 = arith.minimumf %0, %1 : tensor<5xf32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = min
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xf32>
}

// Checks that `minsi(all_reduce(x), all_reduce(y))` on integer tensors with
// `min` reductions gets transformed to `all_reduce(minsi(x, y))`, keeping the
// mesh, the mesh axes and the `min` reduction kind.
// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
func.func @all_reduce_arith_minsi_endomorphism(
    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
    %arg0: tensor<5xi32>,
    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
    %arg1: tensor<5xi32>) -> tensor<5xi32> {
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xi32> -> tensor<5xi32>
  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
    : tensor<5xi32> -> tensor<5xi32>
  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
  %2 = arith.minsi %0, %1 : tensor<5xi32>
  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = min
  // CHECK: return %[[ALL_REDUCE_RES]]
  return %2 : tensor<5xi32>
}
