xref: /llvm-project/mlir/test/Dialect/Vector/vector-unroll-options.mlir (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1" | FileCheck %s --check-prefix=ORDER
// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=0,3,1,2" | FileCheck %s --check-prefix=BATCHED

// The first run unrolls with the pattern's default loop order and checks every
// function below. The two unroll-order runs permute the unrolled loops: the
// 2,0,1 run only checks @vector_contract_f32 and the 0,3,1,2 run only checks
// @vector_contract_batched.

// Contracts lhs(i, k) with rhs(j, k) into acc(i, j), reducing over k, on an
// 8x8x4 (i, j, k) iteration space. For f32 the unrolled contract tile is
// 4x4x2 (lhs/rhs tiles vector<4x2xf32>, accumulator tile vector<4x4xf32>),
// so 2x2x2 = 8 tile contracts are expected.
func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>,
                          %init : vector<8x8xf32>) -> vector<8x8xf32> {
  %0 = vector.contract
         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                           affine_map<(i, j, k) -> (j, k)>,
                           affine_map<(i, j, k) -> (i, j)>],
          iterator_types = ["parallel", "parallel", "reduction"]}
       %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
  return %0 : vector<8x8xf32>
}

// Default order (k innermost): for each (i, j) tile the accumulator slice of
// %init is extracted once, and the k = 2 tile consumes the k = 0 result.
// CHECK-LABEL: func @vector_contract_f32
// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [0, 2]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [0, 2]
//       CHECK:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  CHECK-SAME:   offsets = [0, 4]
//       CHECK:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [0, 2]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [4, 2]
//       CHECK:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [0, 0]
//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [4, 2]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [0, 2]
//       CHECK:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [4, 0]
//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  CHECK-SAME:   offsets = [4, 4]
//       CHECK:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  CHECK-SAME:   offsets = [4, 2]
//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  CHECK-SAME:   offsets = [4, 2]
//       CHECK:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]]
//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       CHECK:   return

// unroll-order=2,0,1 makes k the outermost loop: the four k = 0 tiles (i, j in
// row-major order) are emitted first, then each k = 2 tile accumulates into
// the matching k = 0 result.
// ORDER-LABEL: func @vector_contract_f32
// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  ORDER-SAME:   offsets = [0, 4]
//       ORDER:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [0, 0]
//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [4, 0]
//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]]
//  ORDER-SAME:   offsets = [4, 4]
//       ORDER:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [0, 2]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [0, 2]
//       ORDER:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [0, 2]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [4, 2]
//       ORDER:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [4, 2]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [0, 2]
//       ORDER:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]]
//  ORDER-SAME:   offsets = [4, 2]
//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]]
//  ORDER-SAME:   offsets = [4, 2]
//       ORDER:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]]
//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>

//       ORDER:   return


// Same (i, k) x (j, k) -> (i, j) contraction on f16 with an 8x8x8 iteration
// space. For f16 the unrolled contract tile is 4x4x4, so exactly 2x2x2 = 8
// tile contracts must be produced; the not-contract guard rejects any extra
// (e.g. un-unrolled) contract.
func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
                          %init : vector<8x8xf16>) -> vector<8x8xf16> {
  %0 = vector.contract
         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                           affine_map<(i, j, k) -> (j, k)>,
                           affine_map<(i, j, k) -> (i, j)>],
          iterator_types = ["parallel", "parallel", "reduction"]}
       %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16>
  return %0 : vector<8x8xf16>
}
// CHECK-LABEL: func @vector_contract_f16
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//       CHECK:   vector.contract {
//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
//   CHECK-NOT:   vector.contract
//       CHECK:   return

// Element-wise fma on vector<4x4xf32>, unrolled into 2x2 tiles: exactly four
// 2x2 fma ops are expected, and the not-fma guard rejects a leftover 4x4 fma.
func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf32>) -> vector<4x4xf32> {
  %0 = vector.fma %a, %b, %c: vector<4x4xf32>
  return %0 : vector<4x4xf32>
}
//   CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
//     CHECK-NOT: vector.fma
//         CHECK: return

// Add-reduces vector<4x6xf32> along dimension 1 into a vector<4xf32>
// accumulator using 2x2 tiles. Each 2-row band extracts its accumulator slice
// once and chains the partial result through the three column tiles; the two
// 2-element results are then inserted into a zero vector<4xf32>.
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
  return %0 : vector<4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
//       CHECK:   %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[ACC0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
//       CHECK:   %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]], %[[ACC0]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]], %[[R0]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[R2:.*]] = vector.multi_reduction <add>, %[[E2]], %[[R1]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[ACC1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
//       CHECK:   %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]], %[[ACC1]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]], %[[R3]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
//       CHECK:   %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]], %[[R4]] [1] : vector<2x2xf32> to vector<2xf32>
//       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[R2]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
//       CHECK:   return %[[V2]] : vector<4xf32>


// Add-reduction of vector<8xf32> to f32, unrolled into four 2-element
// vector.reduction ops whose scalar results are combined left to right with
// arith.addf.
func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
  %0 = vector.reduction <add>, %v : vector<8xf32> into f32
  return %0 : f32
}
// CHECK-LABEL: func @vector_reduction(
//  CHECK-SAME:     %[[v:.*]]: vector<8xf32>
//       CHECK:   %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
//       CHECK:   %[[r0:.*]] = vector.reduction <add>, %[[s0]]
//       CHECK:   %[[s1:.*]] = vector.extract_strided_slice %[[v]] {offsets = [2], sizes = [2]
//       CHECK:   %[[r1:.*]] = vector.reduction <add>, %[[s1]]
//       CHECK:   %[[add1:.*]] = arith.addf %[[r0]], %[[r1]]
//       CHECK:   %[[s2:.*]] = vector.extract_strided_slice %[[v]] {offsets = [4], sizes = [2]
//       CHECK:   %[[r2:.*]] = vector.reduction <add>, %[[s2]]
//       CHECK:   %[[add2:.*]] = arith.addf %[[add1]], %[[r2]]
//       CHECK:   %[[s3:.*]] = vector.extract_strided_slice %[[v]] {offsets = [6], sizes = [2]
//       CHECK:   %[[r3:.*]] = vector.reduction <add>, %[[s3]]
//       CHECK:   %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
//       CHECK:   return %[[add3]]

// Rank-4 transpose with permutation [0, 2, 3, 1]. The source is cut into
// 1x2x3x4 tiles (each transposing to a 1x3x4x2 result tile); a source tile at
// offsets [a, b, c, d] is inserted into the result at offsets [a, c, d, b].
func.func @vector_transpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
  %t = vector.transpose %v, [0, 2, 3, 1] : vector<2x4x3x8xf32> to vector<2x3x8x4xf32>
  return %t : vector<2x3x8x4xf32>
}
// CHECK-LABEL: func @vector_transpose
//       CHECK:   %[[VI:.*]] = arith.constant dense<0.000000e+00> : vector<2x3x8x4xf32>
//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T0:.*]] = vector.transpose %[[E0]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V0:.*]] = vector.insert_strided_slice %[[T0]], %[[VI]] {offsets = [0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T1:.*]] = vector.transpose %[[E1]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[T1]], %[[V0]] {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T2:.*]] = vector.transpose %[[E2]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[T2]], %[[V1]] {offsets = [0, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T3:.*]] = vector.transpose %[[E3]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V3:.*]] = vector.insert_strided_slice %[[T3]], %[[V2]] {offsets = [0, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T4:.*]] = vector.transpose %[[E4]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V4:.*]] = vector.insert_strided_slice %[[T4]], %[[V3]] {offsets = [1, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T5:.*]] = vector.transpose %[[E5]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V5:.*]] = vector.insert_strided_slice %[[T5]], %[[V4]] {offsets = [1, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T6:.*]] = vector.transpose %[[E6]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V6:.*]] = vector.insert_strided_slice %[[T6]], %[[V5]] {offsets = [1, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
//       CHECK:   %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
//       CHECK:   %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
//       CHECK:   return %[[V7]] : vector<2x3x8x4xf32>

// -----

// Batched contraction over the (d0, d1, d2, c0) = 8x8x8x4 iteration space,
// reducing c0. Both the default run and the unroll-order=0,3,1,2 run expect
// exactly 16 tile contracts. (No RUN line checks any other prefix for this
// function, so the former unverified single-contract expectation was dropped.)
func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf32>, %init: vector<8x8x8xf32>) -> vector<8x8x8xf32> {
  %0 = vector.contract
         {indexing_maps = [affine_map<(d0,d1,d2,c0) -> (d0,d1,c0)>,
                           affine_map<(d0,d1,d2,c0) -> (d0,d2,c0)>,
                           affine_map<(d0,d1,d2,c0) -> (d0,d1,d2)>],
          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
       %lhs, %rhs, %init : vector<8x8x4xf32>, vector<8x8x4xf32> into vector<8x8x8xf32>
  return %0 : vector<8x8x8xf32>
}


//    CHECK-LABEL: vector_contract_batched
// CHECK-COUNT-16: vector.contract
//      CHECK-NOT: vector.contract
//          CHECK: return


//    BATCHED-LABEL: vector_contract_batched
// BATCHED-COUNT-16: vector.contract
//      BATCHED-NOT: vector.contract
//          BATCHED: return
