xref: /llvm-project/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2
// Test input: a `mul` multi-reduction over dimension 1 of a 2x4 f32 vector,
// combined with the accumulator %acc. The expectations that follow require a
// transpose to 4x2 and then an unrolled chain of four arith.mulf ops, the
// first of which consumes the accumulator.
3func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
4    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
5    return %0 : vector<2xf32>
6}
7
8// CHECK-LABEL: func @vector_multi_reduction
9//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
10//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
11//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
12//       CHECK:   %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
13//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
14//       CHECK:   %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
15//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
16//       CHECK:   %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
17//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
18//       CHECK:   %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
19//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
20
// Test input: a `minnumf` multi-reduction over dimension 1 of a 2x4 f32
// vector with accumulator %acc. The expectations that follow require the
// same transpose-then-unroll shape as the `mul` case, using arith.minnumf.
21func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
22    %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
23    return %0 : vector<2xf32>
24}
25
26// CHECK-LABEL: func @vector_multi_reduction_min
27//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
28//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
29//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
30//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
31//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
32//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
33//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
34//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
35//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
36//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
37//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
38
// Test input: a `maxnumf` multi-reduction over dimension 1 of a 2x4 f32
// vector with accumulator %acc. The expectations that follow require a
// transpose to 4x2 and an unrolled chain of four arith.maxnumf ops.
39func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
40    %0 = vector.multi_reduction <maxnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
41    return %0 : vector<2xf32>
42}
43
44// CHECK-LABEL: func @vector_multi_reduction_max
45//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
46//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
47//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
48//       CHECK:   %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
49//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
50//       CHECK:   %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
51//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
52//       CHECK:   %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
53//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
54//       CHECK:   %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
55//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
56
// Test input: an integer `and` multi-reduction over dimension 1 of a 2x4 i32
// vector with accumulator %acc. The expectations that follow require a
// transpose to 4x2 and an unrolled chain of four arith.andi ops.
57func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
58    %0 = vector.multi_reduction <and>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
59    return %0 : vector<2xi32>
60}
61
62// CHECK-LABEL: func @vector_multi_reduction_and
63//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
64//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
65//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
66//       CHECK:   %[[RV0:.+]] = arith.andi %[[V0]], %[[ACC]] : vector<2xi32>
67//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
68//       CHECK:   %[[RV01:.+]] = arith.andi %[[V1]], %[[RV0]] : vector<2xi32>
69//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
70//       CHECK:   %[[RV012:.+]] = arith.andi %[[V2]], %[[RV01]] : vector<2xi32>
71//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
72//       CHECK:   %[[RESULT_VEC:.+]] = arith.andi %[[V3]], %[[RV012]] : vector<2xi32>
73//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
74
// Test input: an integer `or` multi-reduction over dimension 1 of a 2x4 i32
// vector with accumulator %acc. The expectations that follow require a
// transpose to 4x2 and an unrolled chain of four arith.ori ops.
75func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
76    %0 = vector.multi_reduction <or>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
77    return %0 : vector<2xi32>
78}
79
80// CHECK-LABEL: func @vector_multi_reduction_or
81//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
82//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
83//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
84//       CHECK:   %[[RV0:.+]] = arith.ori %[[V0]], %[[ACC]] : vector<2xi32>
85//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
86//       CHECK:   %[[RV01:.+]] = arith.ori %[[V1]], %[[RV0]] : vector<2xi32>
87//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
88//       CHECK:   %[[RV012:.+]] = arith.ori %[[V2]], %[[RV01]] : vector<2xi32>
89//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
90//       CHECK:   %[[RESULT_VEC:.+]] = arith.ori %[[V3]], %[[RV012]] : vector<2xi32>
91//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
92
// Test input: an integer `xor` multi-reduction over dimension 1 of a 2x4 i32
// vector with accumulator %acc. The expectations that follow require a
// transpose to 4x2 and an unrolled chain of four arith.xori ops.
93func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
94    %0 = vector.multi_reduction <xor>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
95    return %0 : vector<2xi32>
96}
97
98// CHECK-LABEL: func @vector_multi_reduction_xor
99//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
100//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
101//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
102//       CHECK:   %[[RV0:.+]] = arith.xori %[[V0]], %[[ACC]] : vector<2xi32>
103//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
104//       CHECK:   %[[RV01:.+]] = arith.xori %[[V1]], %[[RV0]] : vector<2xi32>
105//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
106//       CHECK:   %[[RV012:.+]] = arith.xori %[[V2]], %[[RV01]] : vector<2xi32>
107//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
108//       CHECK:   %[[RESULT_VEC:.+]] = arith.xori %[[V3]], %[[RV012]] : vector<2xi32>
109//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
110
111
// Test input: an `add` multi-reduction over the two trailing dimensions
// [2, 3] of a 2x3x4x5 i32 vector, producing a 2x3 result. The expectations
// that follow require: a transpose with permutation [2, 3, 0, 1], a
// shape_cast to 20x6, the accumulator flattened to a 6-element vector,
// twenty extract/arith.addi pairs, and a final shape_cast back to 2x3.
112func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
113    %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
114    return %0 : vector<2x3xi32>
115}
116
117// CHECK-LABEL: func @vector_reduction_outer
118//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
119//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32>
120//       CHECK:   %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32>
121//       CHECK:   %[[FACC:.+]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32>
122//       CHECK:   %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<6xi32> from vector<20x6xi32>
123//       CHECK:   %[[R:.+]] = arith.addi %[[V0]], %[[FACC]] : vector<6xi32>
124//       CHECK:   %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<6xi32> from vector<20x6xi32>
125//       CHECK:   %[[R0:.+]] = arith.addi %[[V1]], %[[R]] : vector<6xi32>
126//       CHECK:   %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<6xi32> from vector<20x6xi32>
127//       CHECK:   %[[R1:.+]] = arith.addi %[[V2]], %[[R0]] : vector<6xi32>
128//       CHECK:   %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<6xi32> from vector<20x6xi32>
129//       CHECK:   %[[R2:.+]] = arith.addi %[[V3]], %[[R1]] : vector<6xi32>
130//       CHECK:   %[[V4:.+]] = vector.extract %[[RESHAPED]][4] : vector<6xi32> from vector<20x6xi32>
131//       CHECK:   %[[R3:.+]] = arith.addi %[[V4]], %[[R2]] : vector<6xi32>
132//       CHECK:   %[[V5:.+]] = vector.extract %[[RESHAPED]][5] : vector<6xi32> from vector<20x6xi32>
133//       CHECK:   %[[R4:.+]] = arith.addi %[[V5]], %[[R3]] : vector<6xi32>
134//       CHECK:   %[[V6:.+]] = vector.extract %[[RESHAPED]][6] : vector<6xi32> from vector<20x6xi32>
135//       CHECK:   %[[R5:.+]] = arith.addi %[[V6]], %[[R4]] : vector<6xi32>
136//       CHECK:   %[[V7:.+]] = vector.extract %[[RESHAPED]][7] : vector<6xi32> from vector<20x6xi32>
137//       CHECK:   %[[R6:.+]] = arith.addi %[[V7]], %[[R5]] : vector<6xi32>
138//       CHECK:   %[[V8:.+]] = vector.extract %[[RESHAPED]][8] : vector<6xi32> from vector<20x6xi32>
139//       CHECK:   %[[R7:.+]] = arith.addi %[[V8]], %[[R6]] : vector<6xi32>
140//       CHECK:   %[[V9:.+]] = vector.extract %[[RESHAPED]][9] : vector<6xi32> from vector<20x6xi32>
141//       CHECK:   %[[R8:.+]] = arith.addi %[[V9]], %[[R7]] : vector<6xi32>
142//       CHECK:   %[[V10:.+]] = vector.extract %[[RESHAPED]][10] : vector<6xi32> from vector<20x6xi32>
143//       CHECK:   %[[R9:.+]] = arith.addi %[[V10]], %[[R8]] : vector<6xi32>
144//       CHECK:   %[[V11:.+]] = vector.extract %[[RESHAPED]][11] : vector<6xi32> from vector<20x6xi32>
145//       CHECK:   %[[R10:.+]] = arith.addi %[[V11]], %[[R9]] : vector<6xi32>
146//       CHECK:   %[[V12:.+]] = vector.extract %[[RESHAPED]][12] : vector<6xi32> from vector<20x6xi32>
147//       CHECK:   %[[R11:.+]] = arith.addi %[[V12]], %[[R10]] : vector<6xi32>
148//       CHECK:   %[[V13:.+]] = vector.extract %[[RESHAPED]][13] : vector<6xi32> from vector<20x6xi32>
149//       CHECK:   %[[R12:.+]] = arith.addi %[[V13]], %[[R11]] : vector<6xi32>
150//       CHECK:   %[[V14:.+]] = vector.extract %[[RESHAPED]][14] : vector<6xi32> from vector<20x6xi32>
151//       CHECK:   %[[R13:.+]] = arith.addi %[[V14]], %[[R12]] : vector<6xi32>
152//       CHECK:   %[[V15:.+]] = vector.extract %[[RESHAPED]][15] : vector<6xi32> from vector<20x6xi32>
153//       CHECK:   %[[R14:.+]] = arith.addi %[[V15]], %[[R13]] : vector<6xi32>
154//       CHECK:   %[[V16:.+]] = vector.extract %[[RESHAPED]][16] : vector<6xi32> from vector<20x6xi32>
155//       CHECK:   %[[R15:.+]] = arith.addi %[[V16]], %[[R14]] : vector<6xi32>
156//       CHECK:   %[[V17:.+]] = vector.extract %[[RESHAPED]][17] : vector<6xi32> from vector<20x6xi32>
157//       CHECK:   %[[R16:.+]] = arith.addi %[[V17]], %[[R15]] : vector<6xi32>
158//       CHECK:   %[[V18:.+]] = vector.extract %[[RESHAPED]][18] : vector<6xi32> from vector<20x6xi32>
159//       CHECK:   %[[R17:.+]] = arith.addi %[[V18]], %[[R16]] : vector<6xi32>
160//       CHECK:   %[[V19:.+]] = vector.extract %[[RESHAPED]][19] : vector<6xi32> from vector<20x6xi32>
161//       CHECK:   %[[R18:.+]] = arith.addi %[[V19]], %[[R17]] : vector<6xi32>
162//       CHECK:   %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
163//       CHECK:   return %[[RESULT_VEC]] : vector<2x3xi32>
164
// Test input: an `add` multi-reduction over dimensions [0, 2] of a 3x4x5 f32
// vector, so the only kept dimension (size 4) sits between the two reduced
// ones. The expectations that follow only pin the transpose with permutation
// [0, 2, 1], which moves that kept dimension to the innermost position.
165func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
166    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
167    return %0 : vector<4xf32>
168}
169
170// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
171//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
172//       CHECK: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
173
174// This test mainly guards against a bug where running
175// `InnerOuterDimReductionConversion` on this function resulted in an
176// infinite loop, so it only checks that some value is returned.
177func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 {
  // Reduces the single dimension of a 1-D vector to a scalar f32. Note the
  // reduction kind is spelled in long form (#vector.kind<maxnumf>) here,
  // unlike the short <kind> form used by the other functions in this file.
178  %0 = vector.multi_reduction #vector.kind<maxnumf>, %arg0, %acc [0] : vector<2xf32> to f32
179  return %0 : f32
180}
181// CHECK-LABEL: func @vector_reduction_1D
182//       CHECK:   return %{{.+}}
183
// Test input: an `add` multi-reduction over both dimensions [0, 1] of a 2x3
// f32 vector, producing a scalar f32. The expectations that follow only
// require that the function returns some value; the lowered ops are not
// pinned.
184func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
185  %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
186  return %0 : f32
187}
188// CHECK-LABEL: func @vector_multi_reduction_to_scalar
189//       CHECK:   return %{{.+}}
190
// Transform script that drives the lowering under test. The run line at the
// top of the file invokes mlir-opt with --transform-interpreter; this named
// sequence is presumably the entry point it executes (confirm against the
// transform interpreter documentation).
191module attributes {transform.with_named_sequence} {
192  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    // Collect every func.func nested under the root payload op.
193    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    // Apply the multi-reduction lowering patterns to each matched function,
    // selecting the "innerparallel" lowering strategy.
194    transform.apply_patterns to %func_op {
195      transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
196    } : !transform.op<"func.func">
197    transform.yield
198  }
199}
200