xref: /llvm-project/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (revision a5985ca51dd7e0759d65fac9cb2b6a4448ebc404)
1// RUN: mlir-opt %s -test-vector-to-vector-lowering -split-input-file| FileCheck %s
2
3// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
4// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
5// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
6
7// CHECK-LABEL: cast_away_contraction_leading_one_dims
8//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
9//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
10//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
11//  CHECK-NEXT:   %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
12//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
13//  CHECK-SAME:   %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
14//  CHECK-NEXT:   %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32>
15//  CHECK-NEXT:  return %[[R4]] : vector<1x16x16xf32>
16
// Contraction in which both operands and the accumulator index their first
// (size-1) dimension with the parallel iterator `l`.
#leading_unit_maps = [
  affine_map<(l, i, j, k) -> (l, i, k)>,
  affine_map<(l, i, j, k) -> (l, k, j)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#leading_unit_trait = {
  indexing_maps = #leading_unit_maps,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dims(
    %arg0: vector<1x16x8xf32>,
    %arg1: vector<1x8x16xf32>,
    %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
  %contracted = vector.contract #leading_unit_trait %arg0, %arg1, %arg2
      : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
  return %contracted : vector<1x16x16xf32>
}
31
32// -----
33// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
34// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
35// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
36
37// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_const_mask
38// CHECK:           %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
39// CHECK:           %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
40// CHECK:           %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
41// CHECK:           %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
42// CHECK:           %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
43// CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
44// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
45// CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
46// CHECK:           %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
47// CHECK:           return %[[RES]] : vector<1x16x16xf32>
48
// Same contraction shape as a leading-unit-dim batched matmul, but executed
// under a mask produced by vector.constant_mask.
#const_masked_maps = [
  affine_map<(l, i, j, k) -> (l, i, k)>,
  affine_map<(l, i, j, k) -> (l, k, j)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#const_masked_trait = {
  indexing_maps = #const_masked_maps,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dim_under_const_mask(
    %arg0: vector<1x16x8xf32>,
    %arg1: vector<1x8x16xf32>,
    %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
  // The mask has one entry per iteration-space dimension (l, i, j, k).
  %const_mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
  %masked = vector.mask %const_mask {
    vector.contract #const_masked_trait %arg0, %arg1, %arg2
        : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
  return %masked : vector<1x16x16xf32>
}
66
67// -----
68// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
69// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
70// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
71
72// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_mask
73// CHECK:           %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
74// CHECK:           %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
75// CHECK:           %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
76// CHECK:           %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
77// CHECK:           %[[CONTRACT:.*]] = vector.mask %[[M]] {
78// CHECK-SAME:      vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
79// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
80// CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
81// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
82// CHECK-NEXT:      return %[[RES]] : vector<1x16x16xf32>
83
// Leading-unit-dim contraction executed under a mask that is passed in as a
// function argument rather than built from a constant.
#dyn_masked_maps = [
  affine_map<(l, i, j, k) -> (l, i, k)>,
  affine_map<(l, i, j, k) -> (l, k, j)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#dyn_masked_trait = {
  indexing_maps = #dyn_masked_maps,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dim_under_mask(
    %arg0: vector<1x16x8xf32>,
    %arg1: vector<1x8x16xf32>,
    %arg2: vector<1x16x16xf32>,
    %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
  %masked = vector.mask %mask {
    vector.contract #dyn_masked_trait %arg0, %arg1, %arg2
        : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
  return %masked : vector<1x16x16xf32>
}
104
105// -----
106
107// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
108// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
109// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
110
111// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded
112//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
113//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
114//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %{{.*}}[0, 0] : vector<16xf32> from vector<1x1x16xf32>
115//  CHECK-NEXT:   %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
116//  CHECK-SAME:   iterator_types = ["parallel", "reduction"], kind = #vector.kind<mul>}
117//  CHECK-SAME:   %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32>
118//  CHECK-NEXT:   %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32>
119//  CHECK-NEXT:   %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32>
120//  CHECK-NEXT:  return %[[R5]] : vector<1x1x16xf32>
121
// The LHS map lists iterator `i` ahead of `l`, whereas the RHS and the
// accumulator maps start with `l`. The combining kind is `mul`.
#lhs_swapped_maps = [
  affine_map<(l, i, j, k) -> (i, l, k)>,
  affine_map<(l, i, j, k) -> (l, k, j)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#lhs_swapped_trait = {
  indexing_maps = #lhs_swapped_maps,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"],
  kind = #vector.kind<mul>
}

func.func @cast_away_contraction_leading_one_dims_transposeneeded(
    %arg0: vector<1x1x8xf32>,
    %arg1: vector<1x8x16xf32>,
    %arg2: vector<1x1x16xf32>) -> vector<1x1x16xf32> {
  %contracted = vector.contract #lhs_swapped_trait %arg0, %arg1, %arg2
      : vector<1x1x8xf32>, vector<1x8x16xf32> into vector<1x1x16xf32>
  return %contracted : vector<1x1x16xf32>
}
137
138// -----
139// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
140// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
141// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
142
143// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded2
144//  CHECK-NEXT:   %[[R0:.+]] =  vector.transpose %{{.*}}[1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
145//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %[[R0]][0] : vector<8x16xf32> from vector<1x8x16xf32>
146//  CHECK-NEXT:   %[[R2:.+]] =  vector.transpose %{{.*}}[2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
147//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R2]][0] : vector<2x8xf32> from vector<1x2x8xf32>
148//  CHECK-NEXT:   %[[R4:.+]] =  vector.extract %{{.*}}[0] : vector<2x16xf32> from vector<1x2x16xf32>
149//  CHECK-NEXT:   %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
150//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
151//  CHECK-SAME:   %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
152//  CHECK-NEXT:   %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
153//  CHECK-NEXT:  return %[[R6]] : vector<1x2x16xf32>
154
// Iterator `l` (the size-1 dimension) is the middle result of the LHS map and
// the last result of the RHS map; only the accumulator map has it in front.
#inner_unit_maps = [
  affine_map<(l, i, j, k) -> (k, l, j)>,
  affine_map<(l, i, j, k) -> (i, k, l)>,
  affine_map<(l, i, j, k) -> (l, i, j)>
]
#inner_unit_trait = {
  indexing_maps = #inner_unit_maps,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dims_transposeneeded2(
    %arg0: vector<8x1x16xf32>,
    %arg1: vector<2x8x1xf32>,
    %arg2: vector<1x2x16xf32>) -> vector<1x2x16xf32> {
  %contracted = vector.contract #inner_unit_trait %arg0, %arg1, %arg2
      : vector<8x1x16xf32>, vector<2x8x1xf32> into vector<1x2x16xf32>
  return %contracted : vector<1x2x16xf32>
}
170
171// -----
172// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
173// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
174// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
175
176
177// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4
178//  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<8x1x16xf32> from vector<1x8x1x16xf32>
179//  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0] : vector<2x8x1xf32> from vector<1x2x8x1xf32>
180//  CHECK-NEXT:   %[[R2:.+]] =  vector.transpose %[[R0]], [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
181//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R2]][0] : vector<8x16xf32> from vector<1x8x16xf32>
182//  CHECK-NEXT:   %[[R4:.+]] =  vector.transpose %[[R1]], [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
183//  CHECK-NEXT:   %[[R5:.+]] =  vector.extract %[[R4]][0] : vector<2x8xf32> from vector<1x2x8xf32>
184//  CHECK-NEXT:   %[[R6:.+]] =  vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
185//  CHECK-NEXT:   %[[R7:.+]] =  vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
186//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
187//  CHECK-SAME:   %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
188//  CHECK-NEXT:   %[[R8:.+]] =  vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32>
189//  CHECK-NEXT:   %[[R9:.+]] =  vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
190//  CHECK-NEXT:  return %[[R9]] : vector<1x1x2x16xf32>
191
// Rank-4 operands over a five-iterator space. Every map starts with `m`, and
// the second size-1 iterator `l` appears at a non-leading position in the LHS
// and RHS maps.
#rank4_maps = [
  affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
  affine_map<(m, l, i, j, k) -> (m, i, k, l)>,
  affine_map<(m, l, i, j, k) -> (m, l, i, j)>
]
#rank4_trait = {
  indexing_maps = #rank4_maps,
  iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4(
    %arg0: vector<1x8x1x16xf32>,
    %arg1: vector<1x2x8x1xf32>,
    %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> {
  %contracted = vector.contract #rank4_trait %arg0, %arg1, %arg2
      : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32>
  return %contracted : vector<1x1x2x16xf32>
}
207
208// -----
209// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
210// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
211// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
212
213// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose
214//  CHECK-NEXT:   %[[R0:.+]] =  vector.transpose %{{.*}}, [2, 0, 1, 3] : vector<1x8x1x16xf32> to vector<1x1x8x16xf32>
215//  CHECK-NEXT:   %[[R1:.+]] =  vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<1x2x8x1xf32> to vector<1x1x2x8xf32>
216//  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %[[R0]][0, 0] : vector<8x16xf32> from vector<1x1x8x16xf32>
217//  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R1]][0, 0] : vector<2x8xf32> from vector<1x1x2x8xf32>
218//  CHECK-NEXT:   %[[R4:.+]] =  vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
219//  CHECK-NEXT:   %[[R5:.+]] =  vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
220//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
221//  CHECK-SAME:   %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
222//  CHECK-NEXT:   %[[R6:.+]] =  vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
223//  CHECK-NEXT:   %[[R7:.+]] =  vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
224//  CHECK-NEXT:  return %[[R7]] : vector<1x1x2x16xf32>
225
// Rank-4 variant in which the accumulator map orders its two size-1
// iterators as (l, m) while the LHS and RHS maps start with `m`.
#acc_swapped_maps = [
  affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
  affine_map<(m, l, i, j, k) -> (m, i, k, l)>,
  affine_map<(m, l, i, j, k) -> (l, m, i, j)>
]
#acc_swapped_trait = {
  indexing_maps = #acc_swapped_maps,
  iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
}

func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose(
    %arg0: vector<1x8x1x16xf32>,
    %arg1: vector<1x2x8x1xf32>,
    %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> {
  %contracted = vector.contract #acc_swapped_trait %arg0, %arg1, %arg2
      : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32>
  return %contracted : vector<1x1x2x16xf32>
}
240
241// -----
242
243// CHECK-LABEL:   func.func @cast_away_contraction_does_not_transpose_leading_unit_dims
244// CHECK-NOT:  vector.transpose
245// CHECK:           vector.contract
// Integer contraction whose LHS has two leading size-1 dimensions (d0, d1)
// and whose accumulator is indexed by (d1, d2).
func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(
    %lhs: vector<1x1x8xi32>,
    %rhs: vector<1x8x8xi32>,
    %acc: vector<1x8xi32>) -> vector<1x8xi32> {
  %contracted = vector.contract {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d1, d2)>
    ],
    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
    kind = #vector.kind<add>
  } %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
  return %contracted : vector<1x8xi32>
}
252
253// -----
254// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
func.func @cast_away_extract_strided_slice_leading_one_dims(
    %arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
  // Expected output: extract the leading unit dim, slice the rank-2 vector,
  // then broadcast back to the rank-3 result type.
  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
  // CHECK: return %[[RET]]
  %slice = vector.extract_strided_slice %arg0
      {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]}
      : vector<1x8x8xf16> to vector<1x1x8xf16>
  return %slice : vector<1x1x8xf16>
}
263
264// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable
func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(
    %arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> {
  // Same scenario as the fixed-size strided-slice extraction, with a scalable
  // trailing dimension.
  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16>
  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
  // CHECK: return %[[RET]]
  %slice = vector.extract_strided_slice %arg0
      {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]}
      : vector<1x8x[8]xf16> to vector<1x1x[8]xf16>
  return %slice : vector<1x1x[8]xf16>
}
273
274// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
func.func @cast_away_insert_strided_slice_leading_one_dims(
    %arg0: vector<1x8xf16>,
    %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
  // Expected output: both the inserted value and the destination lose their
  // leading unit dim before the insert, and the result is broadcast back.
  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16>
  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
  // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
  // CHECK: return %[[RET]]
  %inserted = vector.insert_strided_slice %arg0, %arg1
      {offsets = [0, 0, 0], strides = [1, 1]}
      : vector<1x8xf16> into vector<1x8x8xf16>
  return %inserted : vector<1x8x8xf16>
}
284
285// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable
func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(
    %arg0: vector<1x[8]xf16>,
    %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> {
  // Same scenario as the fixed-size strided-slice insertion, with a scalable
  // trailing dimension.
  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16>
  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
  // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16>
  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
  // CHECK: return %[[RET]]
  %inserted = vector.insert_strided_slice %arg0, %arg1
      {offsets = [0, 0, 0], strides = [1, 1]}
      : vector<1x[8]xf16> into vector<1x8x[8]xf16>
  return %inserted : vector<1x8x[8]xf16>
}
295
296// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
297//  CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(
    %arg0: vector<1x1xf16>,
    %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
  // Single-element case: the expected output contains only an extract and a
  // broadcast whose result is returned.
  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1xf16> from vector<1x1xf16>
  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
  // CHECK: return %[[B]]
  %inserted = vector.insert_strided_slice %arg0, %arg1
      {offsets = [0, 0, 0], strides = [1, 1]}
      : vector<1x1xf16> into vector<1x1x1xf16>
  return %inserted : vector<1x1x1xf16>
}
305
306// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable
307//  CHECK-SAME: %[[ARG0:.+]]: vector<1x[1]xf16>, %{{.+}}: vector<1x1x[1]xf16>
func.func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable(
    %arg0: vector<1x[1]xf16>,
    %arg1: vector<1x1x[1]xf16>) -> vector<1x1x[1]xf16> {
  // Single-element case with a scalable trailing dimension of size [1].
  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<[1]xf16> from vector<1x[1]xf16>
  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<[1]xf16> to vector<1x1x[1]xf16>
  // CHECK: return %[[B]]
  %inserted = vector.insert_strided_slice %arg0, %arg1
      {offsets = [0, 0, 0], strides = [1, 1]}
      : vector<1x[1]xf16> into vector<1x1x[1]xf16>
  return %inserted : vector<1x1x[1]xf16>
}
315
316// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
func.func @cast_away_transfer_read_leading_one_dims(
    %arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
  // Expected output: the read yields vector<4xf16> with a single in_bounds
  // entry, and is then broadcast to the original vector<1x4xf16>.
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
  // CHECK: return %[[CAST]]
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.0 : f16
  %read = vector.transfer_read %arg0[%zero, %zero, %zero, %zero], %pad
      {in_bounds = [true, true]}
      : memref<1x4x8x16xf16>, vector<1x4xf16>
  return %read : vector<1x4xf16>
}
328
329// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
func.func @cast_away_masked_transfer_read_leading_one_dims(
    %arg0: memref<1x4x8x16xf16>,
    %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
  // Expected output: the mask operand also loses its leading unit dim (via
  // vector.extract) before being passed to the narrowed read.
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
  // CHECK: return %[[CAST]]
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.0 : f16
  %read = vector.transfer_read %arg0[%zero, %zero, %zero, %zero], %pad, %arg1
      {in_bounds = [true, true]}
      : memref<1x4x8x16xf16>, vector<1x4xf16>
  return %read : vector<1x4xf16>
}
342
343// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
func.func @cast_away_transfer_read_leading_one_dims_one_element(
    %arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
  // Single-element read: only the final broadcast from vector<1xf16> is
  // verified.
  // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.0 : f16
  %read = vector.transfer_read %arg0[%zero, %zero, %zero, %zero], %pad
      {in_bounds = [true, true]}
      : memref<1x1x1x1xf16>, vector<1x1xf16>
  return %read : vector<1x1xf16>
}
351
352// -----
353
354// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
355// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
func.func @cast_away_nontrivial_map_masked_transfer_read(
    %arg0: memref<1x4x8xf16>,
    %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
  // Masked read through a permutation map that swaps the last two memref
  // dimensions. The expected output reshapes the mask with vector.shape_cast
  // and reads a vector<4xf16>.
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
  // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
  // CHECK: return %[[CAST]]
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.0 : f16
  %read = vector.transfer_read %arg0[%zero, %zero, %zero], %pad, %arg1 {
    in_bounds = [true, true, true],
    permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
  } : memref<1x4x8xf16>, vector<1x1x4xf16>
  return %read : vector<1x1x4xf16>
}
370
371// -----
372
373// CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
374// CHECK:      %[[MASK:.+]] = vector.constant_mask
375// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
376// CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
377// CHECK-SAME:   vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
378// CHECK:      return %[[RET]] : vector<1x4xf16>
// A transfer_read nested inside vector.mask with a constant mask. The
// expectations preceding this function require the read to keep its
// vector<1x4xf16> result type.
func.func @not_insert_cast_fo4_transfer_read_under_mask(
    %arg0: memref<1x1x4xf16>) -> vector<1x4xf16> {
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.0 : f16
  %const_mask = vector.constant_mask [1, 3] : vector<1x4xi1>
  %masked_read = vector.mask %const_mask {
    vector.transfer_read %arg0[%zero, %zero, %zero], %pad
        {in_bounds = [true, true]}
        : memref<1x1x4xf16>, vector<1x4xf16>
  } : vector<1x4xi1> -> vector<1x4xf16>
  return %masked_read : vector<1x4xf16>
}
388
389// -----
390
391// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
func.func @cast_away_transfer_write_leading_one_dims(
    %arg0: memref<1x4x8x16xf16>,
    %arg1: vector<1x4xf16>) {
  // Expected output: the written value is narrowed to vector<4xf16> with
  // vector.extract and the write carries a single in_bounds entry.
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
  %zero = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%zero, %zero, %zero, %zero]
      {in_bounds = [true, true]}
      : vector<1x4xf16>, memref<1x4x8x16xf16>
  return
}
401
402// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
func.func @cast_away_masked_transfer_write_leading_one_dims(
    %arg0: memref<1x4x8x16xf16>,
    %arg1: vector<1x4xf16>,
    %arg2: vector<1x4xi1>) {
  // Expected output: both the written value and the mask are narrowed with
  // vector.extract before the write.
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
  %zero = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%zero, %zero, %zero, %zero], %arg2
      {in_bounds = [true, true]}
      : vector<1x4xf16>, memref<1x4x8x16xf16>
  return
}
413
414// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func.func @cast_away_transfer_write_leading_one_dims_one_element(
    %arg0: memref<1x1x1x1xf16>,
    %arg1: vector<1x1xf16>) {
  // Single-element write: only the narrowing vector.extract is verified.
  // CHECK: vector.extract %{{.+}}[0] : vector<1xf16> from vector<1x1xf16>
  %zero = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%zero, %zero, %zero, %zero]
      {in_bounds = [true, true]}
      : vector<1x1xf16>, memref<1x1x1x1xf16>
  return
}
421
422// -----
423
424// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
425// CHECK:      %[[MASK:.+]] = vector.constant_mask
426// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
427// CHECK:      vector.mask %[[CASTED_MASK]] {
428// CHECK-SAME:   vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
429// CHECK:      return
// A transfer_write nested inside vector.mask with a constant mask. The
// expectations preceding this function require the write to keep its
// vector<1x4xf16> operand type.
func.func @not_insert_cast_for_transfer_write_under_mask(
    %arg0: memref<1x1x4xf16>,
    %arg1: vector<1x4xf16>) {
  %zero = arith.constant 0 : index
  %const_mask = vector.constant_mask [1, 3] : vector<1x4xi1>
  vector.mask %const_mask {
    vector.transfer_write %arg1, %arg0[%zero, %zero, %zero]
        {in_bounds = [true, true]}
        : vector<1x4xf16>, memref<1x1x4xf16>
  } : vector<1x4xi1>
  return
}
438
439// -----
440
441// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
442// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
func.func @cast_away_nontrivial_map_masked_transfer_write(
    %arg0: memref<1x4x8xf16>,
    %arg1: vector<1x1x4xf16>,
    %arg2: vector<1x4x1xi1>) {
  // Masked write through a permutation map that swaps the last two memref
  // dimensions. The expected output reshapes the mask with vector.shape_cast
  // and writes a vector<4xf16>.
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
  // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
  // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
  %zero = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%zero, %zero, %zero], %arg2 {
    in_bounds = [true, true, true],
    permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
  } : vector<1x1x4xf16>, memref<1x4x8xf16>
  return
}
455
456// -----
457
458// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
// Exercises four elementwise ops on vectors with leading unit dims: addf,
// cmpf, a select with a vector condition, and a select with a scalar i1
// condition. For each op the expectations require the operands to be
// narrowed with vector.extract and the result broadcast back.
func.func @cast_away_elementwise_leading_one_dims(
  %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
  %arg3: vector<1x4xf32>, %arg4: i1) ->
  (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
  // CHECK:  arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
  // CHECK:  vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
  %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK:  arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
  // CHECK:  vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
  %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK:  select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
  %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
  // Scalar-condition select. Operands are matched with wildcards rather than
  // printed SSA names, so the expectation does not depend on how the output
  // values happen to be numbered.
  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
  // CHECK:  select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xf32>
  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
  %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32>
  return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
}
485
486// -----
487
488// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
489//  CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
490//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
491//       CHECK:   %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32>
492//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
493//       CHECK:   return %[[BCAST]]
// Inserts a scalar at position [0, 0, 0] of a vector with two leading unit
// dims.
func.func @cast_away_insert_leading_one_dims_scalar(
    %s: f32,
    %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
  %updated = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32>
  return %updated : vector<1x1x4xf32>
}
498
499// -----
500
501// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_scalable(
502// CHECK-SAME:    %[[S:.*]]: f32,
503// CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
func.func @cast_away_insert_leading_one_dims_scalar_scalable(
    %s: f32,
    %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
  // Scalar insert into a vector whose trailing dimension is scalable; both
  // fixed leading unit dims are expected to be dropped.
  // CHECK:           %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<[4]xf32> from vector<1x1x[4]xf32>
  // CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32>
  // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
  // CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
  %updated = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32>
  return %updated : vector<1x1x[4]xf32>
}
512
513// -----
514
515// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
516// CHECK-SAME:    %[[S:.*]]: f32,
517// CHECK-SAME:    %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
    %s: f32,
    %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
  // The second dimension is a scalable [1]; the expected output drops only
  // the first (fixed) unit dim and keeps vector<[1]x4xf32>.
  // CHECK:           %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<[1]x4xf32> from vector<1x[1]x4xf32>
  // CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32>
  // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
  // CHECK:           return %[[BCAST]] : vector<1x[1]x4xf32>
  %updated = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32>
  return %updated : vector<1x[1]x4xf32>
}
526
527// -----
528
// Rank-1 source vector<4xf32> inserted at [0, 0] into vector<1x1x4xf32>. The
// expected output broadcasts the source to the destination type and returns
// that broadcast.
529// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
530//  CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
531//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
532//       CHECK:   return %[[BCAST]]
533func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
534  %0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32>
535  return %0: vector<1x1x4xf32>
536}
537
538// -----
539
// Rank-1 scalable source vector<[4]xf32> inserted at [0, 0] into
// vector<1x1x[4]xf32>. The expected output broadcasts the source to the
// destination type and returns that broadcast.
540// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank1_scalable(
541// CHECK-SAME:    %[[S:.*]]: vector<[4]xf32>,
542// CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
543// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
544// CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
545func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
546  %0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32>
547  return %0: vector<1x1x[4]xf32>
548}
549
550// -----
551
// Rank-2 source vector<1x4xf32> inserted at [0] into vector<1x1x4xf32>. The
// expected output extracts vector<4xf32> from the source at [0] and broadcasts
// it to the destination type.
552// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
553//  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
554//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
555//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32>
556//       CHECK:   return %[[BCAST]]
557func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
558  %0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32>
559  return %0: vector<1x1x4xf32>
560}
561
562// -----
563
// Rank-2 source vector<1x[4]xf32> (scalable trailing dim) inserted at [0] into
// vector<1x1x[4]xf32>. The expected output extracts vector<[4]xf32> from the
// source at [0] and broadcasts it to the destination type.
564// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_scalable(
565// CHECK-SAME:    %[[S:.*]]: vector<1x[4]xf32>,
566// CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
567// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
568// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
569// CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
570func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
571  %0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32>
572  return %0: vector<1x1x[4]xf32>
573}
574
575// -----
576
// Source vector<1x4xf32> inserted at [0, 1] into vector<1x2x1x4xf32>. The
// expected output extracts the leading unit dim from both the source and the
// destination, inserts at [1, 0] into vector<2x1x4xf32>, and broadcasts the
// result back to vector<1x2x1x4xf32>.
577// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
578//  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
579//       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
580//       CHECK:   %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<2x1x4xf32> from vector<1x2x1x4xf32>
581//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
582//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
583//       CHECK:   return %[[BCAST]]
584func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> {
585  %0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32>
586  return %0: vector<1x2x1x4xf32>
587}
588
589// -----
590
// Source vector<1x[4]xf32> inserted at [0, 1] into vector<1x2x1x[4]xf32>
// (scalable trailing dim). The expected output extracts the leading unit dim
// from both the source and the destination, inserts at [1, 0] into
// vector<2x1x[4]xf32>, and broadcasts back to vector<1x2x1x[4]xf32>.
591// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
592// CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
593// CHECK-SAME:      %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
594// CHECK:           %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
595// CHECK:           %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<2x1x[4]xf32> from vector<1x2x1x[4]xf32>
596// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
597// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
598// CHECK:           return %[[BCAST]] : vector<1x2x1x[4]xf32>
599func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
600  %0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32>
601  return %0: vector<1x2x1x[4]xf32>
602}
603
604// -----
605
// Source vector<1x4xf32> inserted at [5] into vector<8x1x4xf32>, whose leading
// dim is 8 rather than 1. The expected output extracts only the source's
// leading unit dim, inserts directly into the original destination at [5, 0],
// and returns the insert result itself.
606// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
607//  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
608//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
609//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
610//       CHECK:   return %[[INSERT]]
611func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> {
612  %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
613  return %0: vector<8x1x4xf32>
614}
615
616// -----
617
// Source vector<1x[4]xf32> inserted at [5] into vector<8x1x[4]xf32> (scalable
// trailing dim, leading dest dim of 8). The expected output extracts only the
// source's leading unit dim, inserts directly into the original destination at
// [5, 0], and returns the insert result itself.
618// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
619// CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
620// CHECK-SAME:      %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
621// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
622// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
623// CHECK:           return %[[INSERT]] : vector<8x1x[4]xf32>
624func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
625  %0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32>
626  return %0: vector<8x1x[4]xf32>
627}
628
629// -----
630
// Source vector<1x8xi1> (one leading unit dim) inserted at [0, 0, 7] into
// vector<1x1x8x1x8xi1> (two leading unit dims). The expected output extracts
// the source at [0] and the destination at [0, 0], inserts at [7, 0] into
// vector<8x1x8xi1>, and broadcasts back to vector<1x1x8x1x8xi1>.
631// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
632//  CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
633//       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<8xi1> from vector<1x8xi1>
634//       CHECK:   %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<8x1x8xi1> from vector<1x1x8x1x8xi1>
635//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1>
636//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1>
637//       CHECK:   return %[[BCAST]]
638func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> {
639  %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
640  return %0: vector<1x1x8x1x8xi1>
641}
642
643// -----
644
// Source vector<1x[8]xi1> inserted at [0, 0, 7] into vector<1x1x8x1x[8]xi1>
// (scalable trailing dim). The expected output extracts the source at [0] and
// the destination at [0, 0], inserts at [7, 0] into vector<8x1x[8]xi1>, and
// broadcasts back to vector<1x1x8x1x[8]xi1>.
645// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
646// CHECK-SAME:      %[[S:.*]]: vector<1x[8]xi1>,
647// CHECK-SAME:      %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
648// CHECK:           %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[8]xi1> from vector<1x[8]xi1>
649// CHECK:           %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<8x1x[8]xi1> from vector<1x1x8x1x[8]xi1>
650// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
651// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
652// CHECK:           return %[[BCAST]] : vector<1x1x8x1x[8]xi1>
653func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
654  %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
655  return %0: vector<1x1x8x1x[8]xi1>
656}
657
658// -----
659
// Constant mask [1, 1, 6, 1, 1] on vector<1x1x8x2x1xi1>. The expected output
// is a constant mask [6, 1, 1] on vector<8x2x1xi1> (the two leading unit dims
// and their mask sizes removed) followed by a broadcast back to the original
// type.
660// CHECK-LABEL:   func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
661// CHECK:           %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
662// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
663// CHECK:           return %[[BCAST]] : vector<1x1x8x2x1xi1>
664func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
665  %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
666  return %0: vector<1x1x8x2x1xi1>
667}
668
669// -----
670
// arith.select with a scalar i1 condition and vector<1x16xi1> operands. The
// only expectation is that the output contains an arith.select whose type is
// vector<16xi1>, i.e. with the leading unit dim removed.
671// CHECK-LABEL:   func.func @drop_unit_dims_scalar_cond_select(
672// CHECK:           arith.select {{.*}} : vector<16xi1>
673func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %arg1: vector<1x16xi1>) -> vector<1x16xi1> {
674  %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
675  return %sel : vector<1x16xi1>
676}
677