xref: /llvm-project/mlir/test/Dialect/Vector/vector-gather-lowering.mlir (revision b91d5af1ac3ad2c18b1dfde2061a6ac1d638e6e4)
1// RUN: mlir-opt %s --test-vector-gather-lowering | FileCheck %s
2// RUN: mlir-opt %s --test-vector-gather-lowering --canonicalize | FileCheck %s --check-prefix=CANON
3
// A 1-D gather from a memref is unrolled into one mask-guarded scf.if per
// lane: each lane conditionally does a vector<1xf32> load, extracts the
// scalar, and inserts it into the running result (seeded by the pass-thru
// vector); a disabled lane yields the previous result unchanged.
// CHECK-LABEL: @gather_memref_1d
// CHECK-SAME:    ([[BASE:%.+]]: memref<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
// CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0] : i1 from vector<2xi1>
// CHECK-DAG:     %[[IDX0:.+]]  = vector.extract [[IDXVEC]][0] : index from vector<2xindex>
// CHECK-NEXT:    [[RES0:%.+]]  = scf.if [[M0]] -> (vector<2xf32>)
// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[IDX0]]] : memref<?xf32>, vector<1xf32>
// CHECK-NEXT:      [[ELEM0:%.+]] = vector.extract [[LD0]][0] : f32 from vector<1xf32>
// CHECK-NEXT:      [[INS0:%.+]]  = vector.insert [[ELEM0]], [[PASS]] [0] : f32 into vector<2xf32>
// CHECK-NEXT:      scf.yield [[INS0]] : vector<2xf32>
// CHECK-NEXT:    else
// CHECK-NEXT:      scf.yield [[PASS]] : vector<2xf32>
// CHECK-DAG:     [[M1:%.+]]    = vector.extract [[MASK]][1] : i1 from vector<2xi1>
// CHECK-DAG:     %[[IDX1:.+]]  = vector.extract [[IDXVEC]][1] : index from vector<2xindex>
// CHECK-NEXT:    [[RES1:%.+]]  = scf.if [[M1]] -> (vector<2xf32>)
// CHECK-NEXT:      [[LD1:%.+]]   = vector.load [[BASE]][%[[IDX1]]] : memref<?xf32>, vector<1xf32>
// CHECK-NEXT:      [[ELEM1:%.+]] = vector.extract [[LD1]][0] : f32 from vector<1xf32>
// CHECK-NEXT:      [[INS1:%.+]]  = vector.insert [[ELEM1]], [[RES0]] [1] : f32 into vector<2xf32>
// CHECK-NEXT:      scf.yield [[INS1]] : vector<2xf32>
// CHECK-NEXT:    else
// CHECK-NEXT:      scf.yield [[RES0]] : vector<2xf32>
// CHECK:         return [[RES1]] : vector<2xf32>
func.func @gather_memref_1d(%base: memref<?xf32>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}
30
// Same lowering as above, but with an i32 index vector: the indices are first
// widened with arith.index_cast, and the memref base offset (42) is added to
// each extracted index (arith.addi) before the per-lane conditional load.
// CHECK-LABEL: @gather_memref_1d_i32_index
// CHECK-SAME:    ([[BASE:%.+]]: memref<?xf32>, [[IDXVEC:%.+]]: vector<2xi32>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
// CHECK-DAG:     [[C42:%.+]]   = arith.constant 42 : index
// CHECK-DAG:     [[IDXS:%.+]]  = arith.index_cast [[IDXVEC]] : vector<2xi32> to vector<2xindex>
// CHECK-DAG:     [[IDX0:%.+]]  = vector.extract [[IDXS]][0] : index from vector<2xindex>
// CHECK-NEXT:    %[[OFF0:.+]]  = arith.addi [[IDX0]], [[C42]] : index
// CHECK-NEXT:    [[RES0:%.+]]  = scf.if
// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[OFF0]]] : memref<?xf32>, vector<1xf32>
// CHECK:         else
// CHECK:         [[IDX1:%.+]]  = vector.extract [[IDXS]][1] : index from vector<2xindex>
// CHECK:         %[[OFF1:.+]]  = arith.addi [[IDX1]], [[C42]] : index
// CHECK:         [[RES1:%.+]]  = scf.if
// CHECK-NEXT:      [[LD1:%.+]]   = vector.load [[BASE]][%[[OFF1]]] : memref<?xf32>, vector<1xf32>
// CHECK:         else
// CHECK:         return [[RES1]] : vector<2xf32>
func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
  %c0 = arith.constant 42 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2xi32>, vector<2xi1>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}
51
// A 2-D gather is fully unrolled into 2x3 = 6 mask-guarded scf.if ops (the
// first is spelled out below, the remaining five are counted). The pass-thru
// row is extracted per outer index, the trailing-dim base offset (%c1) is
// added to each extracted index, and the final row is inserted back into the
// 2x3 result.
// CHECK-LABEL: @gather_memref_2d
// CHECK-SAME:    ([[BASE:%.+]]: memref<?x?xf32>, [[IDXVEC:%.+]]: vector<2x3xindex>, [[MASK:%.+]]: vector<2x3xi1>, [[PASS:%.+]]: vector<2x3xf32>)
// CHECK-DAG:     %[[C0:.+]]    = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]]    = arith.constant 1 : index
// CHECK-DAG:     [[PTV0:%.+]]  = vector.extract [[PASS]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0, 0] : i1 from vector<2x3xi1>
// CHECK-DAG:     [[IDX0:%.+]]  = vector.extract [[IDXVEC]][0, 0] : index from vector<2x3xindex>
// CHECK-NEXT:    %[[OFF0:.+]]  = arith.addi [[IDX0]], %[[C1]] : index
// CHECK-NEXT:    [[RES0:%.+]]  = scf.if [[M0]] -> (vector<3xf32>)
// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[C0]], %[[OFF0]]] : memref<?x?xf32>, vector<1xf32>
// CHECK-NEXT:      [[ELEM0:%.+]] = vector.extract [[LD0]][0] : f32 from vector<1xf32>
// CHECK-NEXT:      [[INS0:%.+]]  = vector.insert [[ELEM0]], [[PTV0]] [0] : f32 into vector<3xf32>
// CHECK-NEXT:      scf.yield [[INS0]] : vector<3xf32>
// CHECK-NEXT:    else
// CHECK-NEXT:      scf.yield [[PTV0]] : vector<3xf32>
// CHECK-COUNT-5: scf.if
// CHECK:         [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : vector<3xf32> into vector<2x3xf32>
// CHECK-NEXT:    return [[FINAL]] : vector<2x3xf32>
 func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
 }
76
// A gather whose trailing dimension is scalable is unrolled along the leading
// (fixed) dimension only: one 1-D scalable vector.gather per row, with the
// row results re-assembled via vector.insert into a zero-initialized
// accumulator.
// CHECK-LABEL: @scalable_gather_memref_2d
// CHECK-SAME:      %[[BASE:.*]]: memref<?x?xf32>,
// CHECK-SAME:      %[[IDXVEC:.*]]: vector<2x[3]xindex>,
// CHECK-SAME:      %[[MASK:.*]]: vector<2x[3]xi1>,
// CHECK-SAME:      %[[PASS:.*]]: vector<2x[3]xf32>
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         %[[C1:.*]] = arith.constant 1 : index
// CHECK:         %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
// CHECK:         %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
// CHECK:         %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
// CHECK:         %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>
// CHECK:         %[[GATHER0:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC0]]], %[[MASK0]], %[[PASS0]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
// CHECK:         %[[INS0:.*]] = vector.insert %[[GATHER0]], %[[INIT]] [0] : vector<[3]xf32> into vector<2x[3]xf32>
// CHECK:         %[[IDXVEC1:.*]] = vector.extract %[[IDXVEC]][1] : vector<[3]xindex> from vector<2x[3]xindex>
// CHECK:         %[[MASK1:.*]] = vector.extract %[[MASK]][1] : vector<[3]xi1> from vector<2x[3]xi1>
// CHECK:         %[[PASS1:.*]] = vector.extract %[[PASS]][1] : vector<[3]xf32> from vector<2x[3]xf32>
// CHECK:         %[[GATHER1:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC1]]], %[[MASK1]], %[[PASS1]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
// CHECK:         %[[INS1:.*]] = vector.insert %[[GATHER1]], %[[INS0]] [1] : vector<[3]xf32> into vector<2x[3]xf32>
// CHECK-NEXT:    return %[[INS1]] : vector<2x[3]xf32>
func.func @scalable_gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x[3]xindex>, %mask: vector<2x[3]xi1>, %pass_thru: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
 %c0 = arith.constant 0 : index
 %c1 = arith.constant 1 : index
 %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x[3]xindex>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
 return %0 : vector<2x[3]xf32>
}
102
// The *leading* dimension is scalable here ([4]x8), so the gather cannot be
// unrolled along it; the op must be left untouched (no extracts emitted).
// CHECK-LABEL: @scalable_gather_cant_unroll
// CHECK-NOT: extract
// CHECK: vector.gather
// CHECK-NOT: extract
func.func @scalable_gather_cant_unroll(%base: memref<?x?xf32>, %v: vector<[4]x8xindex>, %mask: vector<[4]x8xi1>, %pass_thru: vector<[4]x8xf32>) -> vector<[4]x8xf32> {
 %c0 = arith.constant 0 : index
 %c1 = arith.constant 1 : index
 %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<[4]x8xindex>, vector<[4]x8xi1>, vector<[4]x8xf32> into vector<[4]x8xf32>
 return %0 : vector<[4]x8xf32>
}
113
// Same per-lane lowering as @gather_memref_1d, but the source is a tensor,
// so each enabled lane uses tensor.extract instead of a vector.load.
// CHECK-LABEL: @gather_tensor_1d
// CHECK-SAME:    ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
// CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0] : i1 from vector<2xi1>
// CHECK-DAG:     %[[IDX0:.+]]  = vector.extract [[IDXVEC]][0] : index from vector<2xindex>
// CHECK-NEXT:    [[RES0:%.+]]  = scf.if [[M0]] -> (vector<2xf32>)
// CHECK-NEXT:      [[ELEM0:%.+]] = tensor.extract [[BASE]][%[[IDX0]]] : tensor<?xf32>
// CHECK-NEXT:      [[INS0:%.+]]  = vector.insert [[ELEM0]], [[PASS]] [0] : f32 into vector<2xf32>
// CHECK-NEXT:      scf.yield [[INS0]] : vector<2xf32>
// CHECK-NEXT:    else
// CHECK-NEXT:      scf.yield [[PASS]] : vector<2xf32>
// CHECK-DAG:     [[M1:%.+]]    = vector.extract [[MASK]][1] : i1 from vector<2xi1>
// CHECK-DAG:     %[[IDX1:.+]]  = vector.extract [[IDXVEC]][1] : index from vector<2xindex>
// CHECK-NEXT:    [[RES1:%.+]]  = scf.if [[M1]] -> (vector<2xf32>)
// CHECK-NEXT:      [[ELEM1:%.+]] = tensor.extract [[BASE]][%[[IDX1]]] : tensor<?xf32>
// CHECK-NEXT:      [[INS1:%.+]]  = vector.insert [[ELEM1]], [[RES0]] [1] : f32 into vector<2xf32>
// CHECK-NEXT:      scf.yield [[INS1]] : vector<2xf32>
// CHECK-NEXT:    else
// CHECK-NEXT:      scf.yield [[RES0]] : vector<2xf32>
// CHECK:         return [[RES1]] : vector<2xf32>
func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}
138
// A gather from a non-unit-stride memref is still lowered when each lane
// reads only a single element: the conditional vector<1xf32> load never
// spans the stride, so it remains legal.
// CHECK-LABEL: @gather_memref_non_unit_stride_read_1_element
// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
// CHECK: %[[IDX:.*]] = vector.extract %arg1[0] : index from vector<1xindex>
// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
// CHECK:   %[[VEC:.*]] = vector.load %arg0[%[[IDX]]] : memref<4xf32, strided<[2]>>, vector<1xf32>
// CHECK:   %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
// CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
// CHECK:   scf.yield %[[RES]] : vector<1xf32>
// CHECK: } else {
// CHECK:    scf.yield %arg3 : vector<1xf32>
// CHECK: }
// CHECK: return %[[RET]] : vector<1xf32>
func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
  return %0 : vector<1xf32>
}
156
// With more than one lane on a non-unit-stride memref, the lowering bails
// out and the vector.gather is left unchanged.
// CHECK-LABEL: @gather_memref_non_unit_stride_read_more_than_1_element
// CHECK: %[[CONST:.*]] = arith.constant 0 : index
// CHECK: %[[RET:.*]] = vector.gather %arg0[%[[CONST]]] [%arg1], %arg2, %arg3 : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
// CHECK: return %[[RET]] : vector<2xf32>
func.func @gather_memref_non_unit_stride_read_more_than_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}
166
// 2-D gather from a tensor: fully unrolled into six mask-guarded
// tensor.extract ops (one per element of the 2x3 result), with the final
// row inserted back into the 2-D result vector.
// CHECK-LABEL: @gather_tensor_2d
// CHECK:  scf.if
// CHECK:    tensor.extract
// CHECK:  else
// CHECK:  scf.if
// CHECK:    tensor.extract
// CHECK:  else
// CHECK:  scf.if
// CHECK:    tensor.extract
// CHECK:  else
// CHECK:  scf.if
// CHECK:    tensor.extract
// CHECK:  else
// CHECK:  scf.if
// CHECK:    tensor.extract
// CHECK:  else
// CHECK:  scf.if
// CHECK:    tensor.extract
// CHECK:  else
// CHECK:       [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : vector<3xf32> into vector<2x3xf32>
// CHECK-NEXT:  return [[FINAL]] : vector<2x3xf32>
 func.func @gather_tensor_2d(%base: tensor<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
 }
194
// Check that all-set and no-set masks get optimized out after canonicalization.
196
// With an all-true mask, --canonicalize folds away every scf.if, leaving
// unconditional tensor.extract + vector.insert chains.
// CANON-LABEL: @gather_tensor_1d_all_set
// CANON-NOT:     scf.if
// CANON:         tensor.extract
// CANON:         tensor.extract
// CANON:         [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : f32 into vector<2xf32>
// CANON-NEXT:    return [[FINAL]] : vector<2xf32>
func.func @gather_tensor_1d_all_set(%base: tensor<?xf32>, %v: vector<2xindex>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
  %mask = arith.constant dense <true> : vector<2xi1>
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}
209
// With an all-false mask, the whole gather canonicalizes away and the
// function simply returns the pass-through vector.
// CANON-LABEL: @gather_tensor_1d_none_set
// CANON-SAME:    ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[PASS:%.+]]: vector<2xf32>)
// CANON-NEXT:    return [[PASS]] : vector<2xf32>
func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
  %mask = arith.constant dense <false> : vector<2xi1>
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}
219
// Check that vector.gather of a strided memref is replaced with a
// vector.gather with indices encoding the original strides. Note that multiple
// patterns are run for this example, e.g.:
//  1. "remove stride from gather source"
//  2. "flatten gather"
// However, the main goal is to test Pattern 1 above: the strided subview is
// replaced by a collapse_shape of the original memref, and the indices are
// multiplied by the stride (3).
#map = affine_map<()[s0] -> (s0 * 4096)>
func.func @strided_gather(%base : memref<100x3xf32>,
                          %idxs : vector<4xindex>,
                          %x : index, %y : index) -> vector<4xf32> {
  %c0 = arith.constant 0 : index
  %x_1 = affine.apply #map()[%x]
  // Strided MemRef
  %subview = memref.subview %base[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
  %mask = arith.constant dense<true> : vector<4xi1>
  %pass_thru = arith.constant dense<0.000000e+00> : vector<4xf32>
  // Gather of a strided MemRef
  %res = vector.gather %subview[%c0] [%idxs], %mask, %pass_thru : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK-LABEL:   func.func @strided_gather(
// CHECK-SAME:                         %[[base:.*]]: memref<100x3xf32>,
// CHECK-SAME:                         %[[IDXS:.*]]: vector<4xindex>,
// CHECK-SAME:                         %[[VAL_4:.*]]: index,
// CHECK-SAME:                         %[[VAL_5:.*]]: index) -> vector<4xf32> {
// CHECK:           %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>

// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
// CHECK:           %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>

// CHECK:           %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
// CHECK:           %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
// CHECK:           scf.if %[[MASK_0]] -> (vector<4xf32>)
// CHECK:             %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
// CHECK:             %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>

// CHECK:           %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
// CHECK:           %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
// CHECK:           scf.if %[[MASK_1]] -> (vector<4xf32>)
// CHECK:             %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
// CHECK:             %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>

// CHECK:           %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
// CHECK:           %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
// CHECK:           scf.if %[[MASK_2]] -> (vector<4xf32>)
// CHECK:             %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
// CHECK:             %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>

// CHECK:           %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
// CHECK:           %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
// CHECK:           scf.if %[[MASK_3]] -> (vector<4xf32>)
// CHECK:             %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
// CHECK:             %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
274
// A fully scalable 1-D gather cannot be unrolled into per-lane loads, so it
// must be left untouched (no extracts emitted).
// CHECK-LABEL: @scalable_gather_1d
// CHECK-NOT: extract
// CHECK: vector.gather
// CHECK-NOT: extract
func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask: vector<[2]xi1>, %pass_thru: vector<[2]xf32>) -> vector<[2]xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32>
  return %0 : vector<[2]xf32>
}
284