// Source: llvm-project/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
// (revision 8e6630391699116641cf390a10476295b7d4b95c)
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// Inner-dimension reduction of a 2-D vector: each of the two rows is reduced
// by its own vector.reduction (seeded with the matching accumulator element)
// and the scalar results are inserted, in order, into the result vector.
// The insert lines reuse the captured reduction results so that the test
// verifies the inserted value really is the reduction result.
func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
    return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
//   CHECK-DAG:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
//       CHECK:       %[[V0:.+]] = vector.extract %[[INPUT]][0]
//       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0]
//       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
//       CHECK:       %[[V1:.+]] = vector.extract %[[INPUT]][1]
//       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][1]
//       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
//       CHECK:       %[[RESULT_VEC:.+]] = vector.insert %[[RV1]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
//       CHECK:       return %[[RESULT_VEC]]

// Reduction over every dimension down to a scalar: the 2-D input is expected
// to be flattened with a single shape_cast and reduced by one vector.reduction.
func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
    return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
//       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
//       CHECK:   %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
//       CHECK:   return %[[REDUCED]]

// Reduction over the two trailing dimensions: the input is reshaped to 6x20,
// each of the 6 rows is reduced against the corresponding accumulator element,
// and the flat 6-element result is reshaped back to 2x3.
func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
    %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
    return %0 : vector<2x3xi32>
}
// CHECK-LABEL: func @vector_reduction_inner
//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
//   CHECK-DAG:       %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32>
//       CHECK:       %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
//       CHECK:       %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<20xi32> from vector<6x20xi32>
//       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : i32 from vector<2x3xi32>
//       CHECK:       %[[V0R:.+]] = vector.reduction <add>, %[[V0]], %[[ACC0]] : vector<20xi32> into i32
//       CHECK:       %[[FLAT_RESULT_VEC_1:.+]] = vector.insert %[[V0R]], %[[FLAT_RESULT_VEC_0]] [0] : i32 into vector<6xi32>
//       CHECK:       %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<20xi32> from vector<6x20xi32>
//       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : i32 from vector<2x3xi32>
//       CHECK:       %[[V1R:.+]] = vector.reduction <add>, %[[V1]], %[[ACC1]] : vector<20xi32> into i32
//       CHECK:       %[[FLAT_RESULT_VEC_2:.+]] = vector.insert %[[V1R]], %[[FLAT_RESULT_VEC_1]] [1] : i32 into vector<6xi32>
//       CHECK:       %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<20xi32> from vector<6x20xi32>
//       CHECK:       %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : i32 from vector<2x3xi32>
//       CHECK:       %[[V2R:.+]] = vector.reduction <add>, %[[V2]], %[[ACC2]] : vector<20xi32> into i32
//       CHECK:       %[[FLAT_RESULT_VEC_3:.+]] = vector.insert %[[V2R]], %[[FLAT_RESULT_VEC_2]] [2] : i32 into vector<6xi32>
//       CHECK:       %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<20xi32> from vector<6x20xi32>
//       CHECK:       %[[ACC3:.+]] = vector.extract %[[ACC]][1, 0] : i32 from vector<2x3xi32>
//       CHECK:       %[[V3R:.+]] = vector.reduction <add>, %[[V3]], %[[ACC3]] : vector<20xi32> into i32
//       CHECK:       %[[FLAT_RESULT_VEC_4:.+]] = vector.insert %[[V3R]], %[[FLAT_RESULT_VEC_3]] [3] : i32 into vector<6xi32>
//       CHECK:       %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<20xi32> from vector<6x20xi32>
//       CHECK:       %[[ACC4:.+]] = vector.extract %[[ACC]][1, 1] : i32 from vector<2x3xi32>
//       CHECK:       %[[V4R:.+]] = vector.reduction <add>, %[[V4]], %[[ACC4]] : vector<20xi32> into i32
//       CHECK:       %[[FLAT_RESULT_VEC_5:.+]] = vector.insert %[[V4R]], %[[FLAT_RESULT_VEC_4]] [4] : i32 into vector<6xi32>
//       CHECK:       %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<20xi32> from vector<6x20xi32>
//       CHECK:       %[[ACC5:.+]] = vector.extract %[[ACC]][1, 2] : i32 from vector<2x3xi32>
//       CHECK:       %[[V5R:.+]] = vector.reduction <add>, %[[V5]], %[[ACC5]] : vector<20xi32> into i32
//       CHECK:       %[[FLAT_RESULT_VEC:.+]] = vector.insert %[[V5R]], %[[FLAT_RESULT_VEC_5]] [5] : i32 into vector<6xi32>
//       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
//       CHECK:       return %[[RESULT]]

// Reduction over non-trailing dims [1, 2]: the reduced dimensions are first
// transposed to the innermost positions, the input is flattened to 10x12, and
// the 10-element result is reshaped back to 2x5.
func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
    %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
    return %0 : vector<2x5xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_transposed
//  CHECK-SAME:    %[[INPUT:.+]]: vector<2x3x4x5xf32>
//       CHECK:     %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
//       CHECK:     vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
//       CHECK:     %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
//       CHECK:       return %[[RESULT]]

// Reduction over the outermost dimension: the reduced dim is transposed to the
// innermost position, and the 2x4 = 8 resulting 3-element vectors are reduced
// in row-major order. Each reduction result is inserted into the flat
// 8-element result, which is then reshaped to 2x4. The insert lines reuse the
// captured reduction results so that the inserted value is actually verified.
func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
    %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
    return %0 : vector<2x4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_ordering
//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x2x4xf32>, %[[ACC:.*]]: vector<2x4xf32>)
//   CHECK-DAG:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32>
//       CHECK:       %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
//       CHECK:       %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
//       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : f32 from vector<2x4xf32>
//       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<3xf32> into f32
//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0]], %[[RESULT_VEC_0]] [0] : f32 into vector<8xf32>
//       CHECK:       %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
//       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : f32 from vector<2x4xf32>
//       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<3xf32> into f32
//       CHECK:       %[[RESULT_VEC_2:.+]] = vector.insert %[[RV1]], %[[RESULT_VEC_1]] [1] : f32 into vector<8xf32>
//       CHECK:       %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
//       CHECK:       %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : f32 from vector<2x4xf32>
//       CHECK:       %[[RV2:.+]] = vector.reduction <mul>, %[[V2]], %[[ACC2]] : vector<3xf32> into f32
//       CHECK:       %[[RESULT_VEC_3:.+]] = vector.insert %[[RV2]], %[[RESULT_VEC_2]] [2] : f32 into vector<8xf32>
//       CHECK:       %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
//       CHECK:       %[[ACC3:.+]] = vector.extract %[[ACC]][0, 3] : f32 from vector<2x4xf32>
//       CHECK:       %[[RV3:.+]] = vector.reduction <mul>, %[[V3]], %[[ACC3]] : vector<3xf32> into f32
//       CHECK:       %[[RESULT_VEC_4:.+]] = vector.insert %[[RV3]], %[[RESULT_VEC_3]] [3] : f32 into vector<8xf32>
//       CHECK:       %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
//       CHECK:       %[[ACC4:.+]] = vector.extract %[[ACC]][1, 0] : f32 from vector<2x4xf32>
//       CHECK:       %[[RV4:.+]] = vector.reduction <mul>, %[[V4]], %[[ACC4]] : vector<3xf32> into f32
//       CHECK:       %[[RESULT_VEC_5:.+]] = vector.insert %[[RV4]], %[[RESULT_VEC_4]] [4] : f32 into vector<8xf32>
//       CHECK:       %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
//       CHECK:       %[[ACC5:.+]] = vector.extract %[[ACC]][1, 1] : f32 from vector<2x4xf32>
//       CHECK:       %[[RV5:.+]] = vector.reduction <mul>, %[[V5]], %[[ACC5]] : vector<3xf32> into f32
//       CHECK:       %[[RESULT_VEC_6:.+]] = vector.insert %[[RV5]], %[[RESULT_VEC_5]] [5] : f32 into vector<8xf32>
//       CHECK:       %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
//       CHECK:       %[[ACC6:.+]] = vector.extract %[[ACC]][1, 2] : f32 from vector<2x4xf32>
//       CHECK:       %[[RV6:.+]] = vector.reduction <mul>, %[[V6]], %[[ACC6]] : vector<3xf32> into f32
//       CHECK:       %[[RESULT_VEC_7:.+]] = vector.insert %[[RV6]], %[[RESULT_VEC_6]] [6] : f32 into vector<8xf32>
//       CHECK:       %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
//       CHECK:       %[[ACC7:.+]] = vector.extract %[[ACC]][1, 3] : f32 from vector<2x4xf32>
//       CHECK:       %[[RV7:.+]] = vector.reduction <mul>, %[[V7]], %[[ACC7]] : vector<3xf32> into f32
//       CHECK:       %[[RESULT_VEC:.+]] = vector.insert %[[RV7]], %[[RESULT_VEC_7]] [7] : f32 into vector<8xf32>
//       CHECK:       %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
//       CHECK:       return %[[RESHAPED_VEC]]

// Masked inner-dimension reduction whose masks are built from dynamic tensor
// sizes (tensor.dim) via vector.create_mask.
func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %c1 = arith.constant 1 : index
  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %c0_1 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1>
  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
  %cst_2 = arith.constant 0.000000e+00 : f32
  %2 = vector.create_mask %dim : vector<4xi1>
  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
  %c0_3 = arith.constant 0 : index
  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
  return %5 : tensor<?xf32>
}

// Verify that the original 2-D mask is sliced row by row and that each slice
// masks the corresponding vector.reduction instance.

// CHECK-LABEL:   func.func @vectorize_dynamic_reduction
// CHECK:           %[[VAL_8:.*]] = tensor.dim
// CHECK:           %[[VAL_9:.*]] = tensor.dim
// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]] : vector<4x8xi1>

// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<8xi1> from vector<4x8xi1>
// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK:           %[[VAL_18:.*]] = vector.insert

// CHECK:           %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<8xi1> from vector<4x8xi1>
// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK:           %[[VAL_23:.*]] = vector.insert

// CHECK:           %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<8xi1> from vector<4x8xi1>
// CHECK:           %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK:           %[[VAL_28:.*]] = vector.insert

// CHECK:           %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<8xi1> from vector<4x8xi1>
// CHECK:           %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK:           %[[VAL_33:.*]] = vector.insert

func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
  %c0_1 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.create_mask %dim : vector<8xi1>
  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1], %cst {in_bounds = [true]} : tensor<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %cst [0] : vector<8xf32> to f32 } : vector<8xi1> -> f32
  return %4 : f32
}

// Verify that a masked 1-D vector.multi_reduction is lowered to a single masked
// vector.reduction. The lowering first expands the 1-D reduction to 2-D
// internally, but the final IR must still contain one vector.reduction over
// vector<8xf32>, guarded by the original mask.

// CHECK-LABEL:   func.func @vectorize_1d_dynamic_reduction(
// CHECK:           %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1>
// CHECK:           %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32

// Masked reduction over the outermost dim of a dynamically-sized 3-D tensor:
// the 3-D mask must be transposed the same way as the data (reduced dim moved
// innermost), and each extracted mask slice must guard its vector.reduction.
func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %c1 = arith.constant 1 : index
  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %c2 = arith.constant 2 : index
  %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %c0_2 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.create_mask %dim, %dim_0, %dim_1 : vector<4x8x16xi1>
  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %cst {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
  %cst_3 = arith.constant 0.000000e+00 : f32
  %2 = vector.create_mask %dim_1, %dim_0 : vector<16x8xi1>
  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
  %c0_4 = arith.constant 0 : index
  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_4, %c0_4] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
  return %5 : tensor<?x?xf32>
}

// CHECK-LABEL:   func.func @vectorize_dynamic_transpose_reduction
// CHECK:           %[[VAL_6:.*]] = tensor.dim
// CHECK:           %[[VAL_7:.*]] = tensor.dim
// CHECK:           %[[VAL_8:.*]] = tensor.dim
// CHECK:           %[[VAL_135:.*]] = vector.create_mask %{{.*}}, %{{.*}}, %{{.*}} : vector<4x8x16xi1>
// CHECK:           %[[VAL_139:.*]] = vector.transpose %[[VAL_135]], [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1>

// Only the first few unrolled instances are matched; they are enough to show
// that the transposed mask is propagated to each vector.reduction:

// CHECK:           %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<4xi1> from vector<8x16x4xi1>
// CHECK:           %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction <add>
// CHECK:           %[[VAL_145:.*]] = vector.insert %[[VAL_144]]

// CHECK:           %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<4xi1> from vector<8x16x4xi1>
// CHECK:           %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction <add>
// CHECK:           %[[VAL_150:.*]] = vector.insert %[[VAL_149]]

// CHECK:           %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<4xi1> from vector<8x16x4xi1>
// CHECK:           %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction <add>
// CHECK:           %[[VAL_155:.*]] = vector.insert %[[VAL_154]]

// CHECK:           %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<4xi1> from vector<8x16x4xi1>
// CHECK:           %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
// CHECK:           %[[VAL_160:.*]] = vector.insert %[[VAL_159]]

// Reduction over dims [0, 2] with the only parallel dim in the middle: the
// parallel dim is expected to be transposed to the front before lowering.
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
    return %0 : vector<4xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
//       CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>

// Reduction of a fixed-size trailing dim of a vector whose middle dim is
// scalable: the lowering unrolls to one vector.reduction per (i, j) position,
// fills a flat vector<[32]xf32> (8 x [4]), and shape_casts it back to 8x[4].
func.func private @vector_multi_reduction_non_scalable_dim(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
  %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
  return %0 : vector<8x[4]xf32>
}
// CHECK-LABEL:   func.func private @vector_multi_reduction_non_scalable_dim(
// CHECK-SAME:                                     %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
// CHECK-SAME:                                     %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>

// CHECK:           %[[VAL_35:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2xf32> from vector<8x[4]x2xf32>
// CHECK:           %[[VAL_36:.*]] = vector.extract %[[VAL_1]][0, 0] : f32 from vector<8x[4]xf32>
// CHECK:           %[[VAL_37:.*]] = vector.reduction <add>, %[[VAL_35]], %[[VAL_36]] : vector<2xf32> into f32
// CHECK:           %[[VAL_38:.*]] = vector.insert %[[VAL_37]], %[[VAL_2]] [0] : f32 into vector<[32]xf32>

// CHECK:           %[[VAL_39:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2xf32> from vector<8x[4]x2xf32>
// CHECK:           %[[VAL_40:.*]] = vector.extract %[[VAL_1]][0, 1] : f32 from vector<8x[4]xf32>
// CHECK:           %[[VAL_41:.*]] = vector.reduction <add>, %[[VAL_39]], %[[VAL_40]] : vector<2xf32> into f32
// CHECK:           %[[VAL_42:.*]] = vector.insert %[[VAL_41]], %[[VAL_38]] [1] : f32 into vector<[32]xf32>

// (... positions [0, 2] through [7, 2] are not matched explicitly ...)

// CHECK:           %[[VAL_159:.*]] = vector.extract %[[VAL_0]][7, 3] : vector<2xf32> from vector<8x[4]x2xf32>
// CHECK:           %[[VAL_160:.*]] = vector.extract %[[VAL_1]][7, 3] : f32 from vector<8x[4]xf32>
// CHECK:           %[[VAL_161:.*]] = vector.reduction <add>, %[[VAL_159]], %[[VAL_160]] : vector<2xf32> into f32
// CHECK:           %[[VAL_162:.*]] = vector.insert %[[VAL_161]], %{{.*}} [31] : f32 into vector<[32]xf32>

// CHECK:           %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
// CHECK:           return %[[VAL_163]] : vector<8x[4]xf32>

// Check that OneDimMultiReductionToTwoDim handles a scalable dim: a masked 1-D
// multi_reduction over vector<[4]xf32> must still lower to a single masked
// vector.reduction that uses the original operands and mask.
func.func @vector_multi_reduction_scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
    %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
    return %0 : f32
}

// CHECK-LABEL:  func.func @vector_multi_reduction_scalable_dim_1d(
// CHECK-SAME:                                      %[[ARG_0:.*]]: vector<[4]xf32>,
// CHECK-SAME:                                      %[[ARG_1:.*]]: f32,
// CHECK-SAME:                                      %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
// CHECK:          %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
// CHECK:          return %[[VAL_2]] : f32

// Masked inner-dim reduction of a 2-D vector whose reduced dim is scalable:
// each row of the data, accumulator, and mask is extracted and reduced by a
// masked vector.reduction, and the results are inserted into vector<2xf32>.
func.func @vector_multi_reduction_scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
    %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
    return %0 : vector<2xf32>
}

// CHECK-LABEL:  func.func @vector_multi_reduction_scalable_dim_2d(
// CHECK-SAME:                                      %[[ARG_0:.*]]: vector<2x[4]xf32>,
// CHECK-SAME:                                      %[[ARG_1:.*]]: vector<2xf32>,
// CHECK-SAME:                                      %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
// CHECK-DAG:      %[[C0_2xf32:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// CHECK:          %[[ARG0_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK:          %[[ARG1_0:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
// CHECK:          %[[ARG2_0:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
// CHECK:          %[[REDUCE_0:.*]] = vector.mask %[[ARG2_0]] { vector.reduction <add>, %[[ARG0_0]], %[[ARG1_0]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
// CHECK:          %[[INSERT_0:.*]] = vector.insert %[[REDUCE_0]], %[[C0_2xf32]] [0] : f32 into vector<2xf32>
// CHECK:          %[[ARG0_1:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK:          %[[ARG1_1:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
// CHECK:          %[[ARG2_1:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
// CHECK:          %[[REDUCE_1:.*]] = vector.mask %[[ARG2_1]] { vector.reduction <add>, %[[ARG0_1]], %[[ARG1_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
// CHECK:          %[[INSERT_1:.*]] = vector.insert %[[REDUCE_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
// CHECK:          return %[[INSERT_1]] : vector<2xf32>

// Transform script run by --transform-interpreter: applies the
// multi-reduction lowering patterns with the "innerreduction" strategy to
// every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
    } : !transform.op<"func.func">
    transform.yield
  }
}
