// Source xref: /llvm-project/mlir/test/Dialect/Vector/vector-transforms.mlir (revision 6626ed6f9fae79d35aba504f50bac4375686a03b)
// RUN: mlir-opt %s -test-vector-to-vector-lowering="unroll" | FileCheck %s

// CHECK-DAG: #[[MAP1:map[0-9]*]] = affine_map<(d0, d1, d2) -> (d1, d2)>

// Elementwise op on a 4x2 vector is unrolled into two 2x2 tiles.
// CHECK-LABEL: func @add4x2
//      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1:.*]] = arith.addf %[[S1]], %[[S2]] : vector<2x2xf32>
// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[A1]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A2:.*]] = arith.addf %[[S3]], %[[S4]] : vector<2x2xf32>
// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[A2]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32>
// CHECK-NEXT: return %[[VEC1]] : vector<4x2xf32>

func.func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
  %1 = arith.addf %0, %0: vector<4x2xf32>
  return %1: vector<4x2xf32>
}

// Regression test. Previously, this example would trigger
// CastAwayElementwiseLeadingOneDim as:
//    * `vector<2x[4]x1xf32>` would be reformulated as
//    * `vector<2x4x1xf32>`.
// With the updated shape, the conversion pattern would incorrectly assume that
// some leading dims have been dropped.
// CHECK-LABEL:   func.func @no_change(
// CHECK-SAME:      %[[VAL_0:.*]]: vector<2x[4]x1xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: vector<2x[4]x1xf32>)
// CHECK-NEXT:           %[[VAL_2:.*]] = arith.mulf %[[VAL_0]], %[[VAL_1]] : vector<2x[4]x1xf32>
// CHECK-NEXT:           return %[[VAL_2]]
func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) -> vector<2x[4]x1xf32> {
  %1 = arith.mulf %arg0, %arg1 : vector<2x[4]x1xf32>
  return %1 : vector<2x[4]x1xf32>
}

// CHECK-LABEL:   func.func @cast_away_leading_one_dim(
// CHECK:           %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32>
// CHECK:           vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
  %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
  return %1: vector<1x4x1xf32>
}

// Same as above, but with a scalable middle dimension.
// CHECK-LABEL:   func.func @cast_away_leading_one_dim_scalable(
// CHECK:           %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<[4]x1xf32>
// CHECK:           vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
  %1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32>
  return %1: vector<1x[4]x1xf32>
}

// Two chained 4x4 addf ops are each unrolled into four 2x2 tiles; the second
// addf consumes the first's per-tile results directly (no reassembly between).
// CHECK-LABEL: func @add4x4
//      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>

// CHECK-NEXT: %[[A1:.*]] = arith.addf %[[S1]], %[[S2]] : vector<2x2xf32>

// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>

// CHECK-NEXT: %[[A2:.*]] = arith.addf %[[S3]], %[[S4]] : vector<2x2xf32>

// CHECK-NEXT: %[[S5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A3:.*]] = arith.addf %[[S5]], %[[S6]] : vector<2x2xf32>

// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[S8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A4:.*]] = arith.addf %[[S7]], %[[S8]] : vector<2x2xf32>

// CHECK-NEXT: %[[S9:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A5:.*]] = arith.addf %[[S9]], %[[A1]] : vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.insert_strided_slice %[[A5]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>


// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A6:.*]] = arith.addf %[[S11]], %[[A2]] : vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.insert_strided_slice %[[A6]], %[[R1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>

// CHECK-NEXT: %[[S13:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A7:.*]] = arith.addf %[[S13]], %[[A3]] : vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.insert_strided_slice %[[A7]], %[[R2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>

// CHECK-NEXT: %[[S15:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A8:.*]] = arith.addf %[[S15]], %[[A4]] : vector<2x2xf32>
// CHECK-NEXT: %[[R4:.*]] = vector.insert_strided_slice %[[A8]], %[[R3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>

// CHECK-NEXT: return %[[R4]] : vector<4x4xf32>

func.func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
  %2 = arith.addf %0, %1: vector<4x4xf32>
  %3 = arith.addf %1, %2: vector<4x4xf32>
  return %3: vector<4x4xf32>
}

// CHECK-LABEL: func @contraction4x4_ikj_xfer_read

// CHECK-DAG:      %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:      %[[C0:.*]] = arith.constant 0 : index

// Check that the LHS vector.transfer_read is split for each user.

//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return

// Shared contraction maps/trait: (i, k) x (k, j) -> (i, j) matmul-style
// contraction, also used by @contraction4x4_ikj_xfer_read_tensor below.
#contraction_accesses1 = [
  affine_map<(i, k, j) -> (i, k)>,
  affine_map<(i, k, j) -> (k, j)>,
  affine_map<(i, k, j) -> (i, j)>
]
#contraction_trait1 = {
  indexing_maps = #contraction_accesses1,
  iterator_types = ["parallel", "reduction", "parallel"]
}

func.func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
                                   %arg1 : memref<2x4xf32>,
                                   %arg2 : memref<4x4xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32

  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
      : memref<4x2xf32>, vector<4x2xf32>

  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
    : memref<2x4xf32>, vector<2x4xf32>

  %2 = vector.transfer_read %arg2[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
      : memref<4x4xf32>, vector<4x4xf32>

  %3 = vector.contract #contraction_trait1 %0, %1, %2
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>

  vector.transfer_write %3, %arg2[%c0, %c0]
    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : vector<4x4xf32>, memref<4x4xf32>
  return
}

// TODO: Update test with VTR split transform.
// CHECK-LABEL: func @vector_transfers
// CHECK-COUNT-8: vector.transfer_read
// CHECK-COUNT-4: arith.addf
// CHECK-COUNT-4: vector.transfer_write

func.func @vector_transfers(%arg0: index, %arg1: index) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  %cst_0 = arith.constant 1.000000e+00 : f32
  %cst_1 = arith.constant 2.000000e+00 : f32
  affine.for %arg2 = 0 to %arg0 step 4 {
    affine.for %arg3 = 0 to %arg1 step 4 {
      %4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
      %5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
      %6 = arith.addf %4, %5 : vector<4x4xf32>
      vector.transfer_write %6, %2[%arg2, %arg3] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : vector<4x4xf32>, memref<?x?xf32>
    }
  }
  return
}

// CHECK-LABEL: func @elementwise_unroll
//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
//       CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
//       CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
//       CHECK:   %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[CMP0:.*]] = arith.cmpf ult, %[[VT0]], %[[VT4]] : vector<2x2xf32>
//       CHECK:   %[[CMP1:.*]] = arith.cmpf ult, %[[VT1]], %[[VT5]] : vector<2x2xf32>
//       CHECK:   %[[CMP2:.*]] = arith.cmpf ult, %[[VT2]], %[[VT6]] : vector<2x2xf32>
//       CHECK:   %[[CMP3:.*]] = arith.cmpf ult, %[[VT3]], %[[VT7]] : vector<2x2xf32>
//       CHECK:   %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
//       CHECK:   %[[SEL0:.*]] = arith.select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
//       CHECK:   %[[SEL1:.*]] = arith.select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
//       CHECK:   %[[SEL2:.*]] = arith.select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
//       CHECK:   %[[SEL3:.*]] = arith.select %[[CMP3]], %[[VT3]], %[[VT7]] : vector<2x2xi1>, vector<2x2xf32>
//       CHECK:   vector.transfer_write %[[SEL0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
//       CHECK:   vector.transfer_write %[[SEL1]], %[[ARG0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
//       CHECK:   vector.transfer_write %[[SEL2]], %[[ARG0]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
//       CHECK:   vector.transfer_write %[[SEL3]], %[[ARG0]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
func.func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  %cond = arith.cmpf ult, %0, %1 : vector<4x4xf32>
  // The vector transfer split pattern only supports a single user right now.
  %2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  %3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
  %4 = arith.select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
  vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
  return
}

// Check that vector.transfer read/write are split based on contract unrolling.
// NOTE(review): there is deliberately no CHECK-LABEL here — these directives
// reuse %[[C2]]/%[[C0]] bound in the @elementwise_unroll block above, and
// FileCheck variables do not cross CHECK-LABEL boundaries.
//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[R2]], %[[VTW1]][%[[C2]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[R3]], %[[VTW2]][%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>

func.func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
                                          %arg1 : tensor<2x4xf32>,
                                          %arg2 : tensor<4x4xf32>) ->
  tensor<4x4xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 :
    tensor<4x2xf32>, vector<4x2xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 :
    tensor<2x4xf32>, vector<2x4xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0], %cf0 :
    tensor<4x4xf32>, vector<4x4xf32>
  %3 = vector.contract #contraction_trait1 %0, %1, %2
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
  %r = vector.transfer_write %3, %arg2[%c0, %c0]
      : vector<4x4xf32>, tensor<4x4xf32>
  return %r : tensor<4x4xf32>
}

// CHECK-LABEL: func @bubble_down_bitcast_in_extract
//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
  %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
  // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : f32 from vector<4xf32>
  // CHECK:  %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [0] : f32 into vector<1xf32>
  // CHECK:    %[[CAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<1xf32> to vector<2xf16>
  // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : f16 from vector<2xf16>
  %1 = vector.extract %0[3] : f16 from vector<8xf16>
  // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : f32 from vector<4xf32>
  // CHECK:  %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [0] : f32 into vector<1xf32>
  // CHECK:    %[[CAST2:.+]] = vector.bitcast %[[INSERT3]] : vector<1xf32> to vector<2xf16>
  // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : f16 from vector<2xf16>
  %2 = vector.extract %0[4] : f16 from vector<8xf16>
  // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
  return %1, %2: f16, f16
}

// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract
//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func.func @bubble_down_bitcast_in_strided_slice_extract(%arg0: vector<4xf32>) -> vector<4xf16> {
  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2xf32> to vector<4xf16>
  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %0 = vector.extract_strided_slice %cast {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
  // CHECK: return %[[CAST]]
  return %0: vector<4xf16>
}

// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim
//  CHECK-SAME: %[[SRC:.+]]: vector<4x2xf32>
func.func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim(%arg0: vector<4x2xf32>) -> vector<2x4xf16> {
  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [1], sizes = [2], strides = [1]} : vector<4x2xf32> to vector<2x2xf32>
  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2x2xf32> to vector<2x4xf16>
  %cast = vector.bitcast %arg0: vector<4x2xf32> to vector<4x4xf16>
  %0 = vector.extract_strided_slice %cast {offsets = [1], sizes = [2], strides = [1]} : vector<4x4xf16> to vector<2x4xf16>
  // CHECK: return %[[CAST]]
  return %0: vector<2x4xf16>
}

// Negative case: an odd f16 offset does not map to a whole f32 element, so the
// bitcast cannot be bubbled down past the extract.
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_offset
func.func @bubble_down_bitcast_in_strided_slice_extract_odd_offset(%arg0: vector<4xf32>) -> vector<4xf16> {
  // CHECK: vector.bitcast
  // CHECK-NEXT: vector.extract_strided_slice
  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %0 = vector.extract_strided_slice %cast {offsets = [3], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
  return %0: vector<4xf16>
}

// Negative case: an odd f16 size does not cover whole f32 elements, so the
// bitcast cannot be bubbled down past the extract.
// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_size
func.func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4xf32>) -> vector<3xf16> {
  // CHECK: vector.bitcast
  // CHECK-NEXT: vector.extract_strided_slice
  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
  %0 = vector.extract_strided_slice %cast {offsets = [0], sizes = [3], strides = [1]} : vector<8xf16> to vector<3xf16>
  return %0: vector<3xf16>
}

// CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_i4_i8(
// CHECK-SAME:                                                 %[[VAL:.*]]: vector<32xi4>,
// CHECK-SAME:                                                 %[[DST:.*]]: vector<8x32xi4>) -> vector<8x16xi8> {
func.func @bubble_up_bitcast_in_insert_i4_i8(%val: vector<32xi4>, %src: vector<8x32xi4>) -> vector<8x16xi8> {
// CHECK:           %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<32xi4> to vector<16xi8>
// CHECK:           %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x32xi4> to vector<8x16xi8>
// CHECK:           vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xi8> into vector<8x16xi8>
  %0 = vector.insert %val, %src[4] : vector<32xi4> into vector<8x32xi4>
  %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
  return %1 : vector<8x16xi8>
}

// CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_i8_i4(
// CHECK-SAME:                                                 %[[VAL:.*]]: vector<16xi8>,
// CHECK-SAME:                                                 %[[DST:.*]]: vector<8x16xi8>) -> vector<8x32xi4> {
func.func @bubble_up_bitcast_in_insert_i8_i4(%val: vector<16xi8>, %src: vector<8x16xi8>) -> vector<8x32xi4> {
// CHECK:           %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi8> to vector<32xi4>
// CHECK:           %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi8> to vector<8x32xi4>
// CHECK:           vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<32xi4> into vector<8x32xi4>
  %0 = vector.insert %val, %src[4] : vector<16xi8> into vector<8x16xi8>
  %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
  return %1 : vector<8x32xi4>
}

// CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_i32_f32(
// CHECK-SAME:                                                 %[[VAL:.*]]: vector<16xi32>,
// CHECK-SAME:                                                 %[[DST:.*]]: vector<8x16xi32>) -> vector<8x16xf32> {
func.func @bubble_up_bitcast_in_insert_i32_f32(%val: vector<16xi32>, %src: vector<8x16xi32>) -> vector<8x16xf32> {
// CHECK:           %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi32> to vector<16xf32>
// CHECK:           %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi32> to vector<8x16xf32>
// CHECK:           vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xf32> into vector<8x16xf32>
  %0 = vector.insert %val, %src[4] : vector<16xi32> into vector<8x16xi32>
  %1 = vector.bitcast %0 : vector<8x16xi32> to vector<8x16xf32>
  return %1 : vector<8x16xf32>
}

// Negative case: inserting a scalar — the bitcast stays below the insert.
// CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_scalar(
func.func @bubble_up_bitcast_in_insert_scalar(%val: i8, %src: vector<8x16xi8>) -> vector<8x32xi4> {
// CHECK:           vector.insert
// CHECK-NEXT:      vector.bitcast
  %0 = vector.insert %val, %src[4, 8] : i8 into vector<8x16xi8>
  %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
  return %1 : vector<8x32xi4>
}

// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
//  CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
func.func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
  // CHECK-DAG: %[[CAST_SRC1:.+]] = vector.bitcast %[[SRC1]] : vector<4xf16> to vector<2xf32>
  // CHECK-DAG: %[[CAST_SRC2:.+]] = vector.bitcast %[[SRC2]] : vector<4xf16> to vector<2xf32>
  // CHECK-DAG: %[[CAST_DST:.+]] = vector.bitcast %[[DST]] : vector<8xf16> to vector<4xf32>
  // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST_SRC1]], %[[CAST_DST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
  // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST_SRC2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
  %0 = vector.insert_strided_slice %src1, %dst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
  %1 = vector.insert_strided_slice %src2, %0   {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
  %cast = vector.bitcast %1: vector<8xf16> to vector<4xf32>
  // CHECK: return %[[INSERT2]]
  return %cast: vector<4xf32>
}

// Negative case: odd f16 offset — the bitcast cannot be bubbled up.
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_offset
func.func @bubble_up_bitcast_in_strided_slice_insert_odd_offset(%dst: vector<8xf16>, %src: vector<4xf16>) -> vector<4xf32> {
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %0 = vector.insert_strided_slice %src, %dst {offsets = [3], strides = [1]} : vector<4xf16> into vector<8xf16>
  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
  return %cast: vector<4xf32>
}

// Negative case: source and destination ranks differ — no bubbling.
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_different_rank
func.func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector<16x4x8xf16>, %src: vector<2x4xf16>) -> vector<16x4x4xf32> {
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %0 = vector.insert_strided_slice %src, %dst {offsets = [0, 0, 2], strides = [1, 1]} : vector<2x4xf16> into vector<16x4x8xf16>
  %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
  return %cast: vector<16x4x4xf32>
}

// Negative case: the inserted slice is narrower than one f32 element.
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_shape
func.func @bubble_up_bitcast_in_strided_slice_insert_odd_shape(%dst: vector<2xf16>, %src: vector<1xf16>) -> vector<1xf32> {
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
  %cast = vector.bitcast %0: vector<2xf16> to vector<1xf32>
  return %cast: vector<1xf32>
}

// Negative case: odd-sized slice does not cover whole f32 elements.
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape
func.func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape(%dst: vector<8xf16>, %src: vector<3xf16>) -> vector<4xf32> {
  // CHECK: vector.insert_strided_slice
  // CHECK-NEXT: vector.bitcast
  %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<3xf16> into vector<8xf16>
  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
  return %cast: vector<4xf32>
}

// Make sure we do not crash on 0-D vectors.
// CHECK-LABEL:func.func @vec_0D
// CHECK-NEXT:vector.bitcast
func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
  %0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
  return %0 : vector<i32>
}

// Make sure we do not crash on a dynamic-index `vector.extract`:
func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 {
  %0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16>
  %1 = vector.extract %0[%index] : i16 from vector<8xi16>
  return %1 : i16
}

// CHECK-LABEL: func.func @vector_extract_dynamic_index
// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 {
// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16>
// CHECK: return %[[EXTRACT]]
449