// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf,expand-strided-metadata,lower-affine,convert-arith-to-llvm,convert-scf-to-cf),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" | \
// RUN: mlir-runner -e entry -entry-point-result=void  \
// RUN:   -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true},expand-strided-metadata,lower-affine,convert-arith-to-llvm,convert-scf-to-cf),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" | \
// RUN: mlir-runner -e entry -entry-point-result=void  \
// RUN:   -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Test for special cases of 1D vector transfer ops.

memref.global "private" @gv : memref<5x6xf32> =
    dense<[[0. , 1. , 2. , 3. , 4. , 5. ],
           [10., 11., 12., 13., 14., 15.],
           [20., 21., 22., 23., 24., 25.],
           [30., 31., 32., 33., 34., 35.],
           [40., 41., 42., 43., 44., 45.]]>
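// Element (i, j) of @gv holds the value 10 * i + j, so every printed result
// below maps directly back to a (row, column) position in this buffer.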

// Non-contiguous, strided load.
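// With permutation_map (d0, d1) -> (d0), consecutive vector lanes step along
// dim 0 of %A, i.e. lane i reads %A[%base1 + i][%base2]. Lanes that run past
// the last row take the padding value -42.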
func.func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : memref<?x?xf32>, vector<9xf32>
  vector.print %f: vector<9xf32>
  return
}

// Vector load with unit stride only on last dim.
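// The loops below take a 1x2 subview at each (row, col) in {1, 3} x {0, 3} and
// read one 2-element vector from it, so four vectors are printed (see the
// checks in @entry).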
func.func @transfer_read_1d_unit_stride(%A : memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %fm42 = arith.constant -42.0: f32
  scf.for %arg2 = %c1 to %c5 step %c2 {
    scf.for %arg3 = %c0 to %c6 step %c3 {
      %0 = memref.subview %A[%arg2, %arg3] [1, 2] [1, 1]
          : memref<?x?xf32> to memref<1x2xf32, strided<[?, 1], offset: ?>>
      %1 = vector.transfer_read %0[%c0, %c0], %fm42 {in_bounds=[true]}
          : memref<1x2xf32, strided<[?, 1], offset: ?>>, vector<2xf32>
      vector.print %1 : vector<2xf32>
    }
  }
  return
}

// Vector load with unit stride only on last dim. Strides are not static, so
// codegen must go through VectorToSCF 1D lowering.
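// The reinterpret_cast below views %A with offset 6 and strides [6, 1], so the
// read at [%c2, %c1] starts at flat element 6 + 2 * 6 + 1 = 19 of @gv, i.e. at
// the value 31 (see the check in @entry).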
func.func @transfer_read_1d_non_static_unit_stride(%A : memref<?x?xf32>) {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %c6 = arith.constant 6 : index
  %fm42 = arith.constant -42.0: f32
  %1 = memref.reinterpret_cast %A to offset: [%c6], sizes: [%c4, %c6], strides: [%c6, %c1]
      : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  %2 = vector.transfer_read %1[%c2, %c1], %fm42 {in_bounds=[true]}
      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
  vector.print %2 : vector<4xf32>
  return
}

// Vector load where last dim has non-unit stride.
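// The 4x3 view below has strides [6, 2] over the flat buffer, so the read at
// [%c2, %c1] touches flat elements 14 and 16 (values 22 and 24); the third
// lane falls outside the view and takes the padding value -42.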
func.func @transfer_read_1d_non_unit_stride(%A : memref<?x?xf32>) {
  %B = memref.reinterpret_cast %A to offset: [0], sizes: [4, 3], strides: [6, 2]
      : memref<?x?xf32> to memref<4x3xf32, strided<[6, 2]>>
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %fm42 = arith.constant -42.0: f32
  %vec = vector.transfer_read %B[%c2, %c1], %fm42 {in_bounds=[false]}
      : memref<4x3xf32, strided<[6, 2]>>, vector<3xf32>
  vector.print %vec : vector<3xf32>
  return
}

// Broadcast.
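// With permutation_map (d0, d1) -> (0), every lane of the result reads the
// same scalar %A[%base1][%base2], i.e. that value is broadcast across the
// vector.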
func.func @transfer_read_1d_broadcast(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (0)>}
      : memref<?x?xf32>, vector<9xf32>
  vector.print %f: vector<9xf32>
  return
}

// Non-contiguous, strided, in-bounds load.
func.func @transfer_read_1d_in_bounds(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
      : memref<?x?xf32>, vector<3xf32>
  vector.print %f: vector<3xf32>
  return
}

// Non-contiguous, strided, masked load.
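// Lanes whose mask bit is 0 are not read from memory and take the padding
// value -42, just like out-of-bounds lanes.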
func.func @transfer_read_1d_mask(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : memref<?x?xf32>, vector<9xf32>
  vector.print %f: vector<9xf32>
  return
}

// Non-contiguous, out-of-bounds, strided load.
func.func @transfer_read_1d_out_of_bounds(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %fm42
      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
      : memref<?x?xf32>, vector<3xf32>
  vector.print %f: vector<3xf32>
  return
}

// Non-contiguous, strided, masked, in-bounds load.
func.func @transfer_read_1d_mask_in_bounds(
    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fm42 = arith.constant -42.0: f32
  %mask = arith.constant dense<[1, 0, 1]> : vector<3xi1>
  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
      : memref<?x?xf32>, vector<3xf32>
  vector.print %f: vector<3xf32>
  return
}

// Non-contiguous, strided store.
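// With permutation_map (d0, d1) -> (d0), lane i of %vf0 is stored to
// %A[%base1 + i][%base2]; lanes that fall outside the memref are not written.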
func.func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fn1 = arith.constant -1.0 : f32
  %vf0 = vector.splat %fn1 : vector<7xf32>
  vector.transfer_write %vf0, %A[%base1, %base2]
    {permutation_map = affine_map<(d0, d1) -> (d0)>}
    : vector<7xf32>, memref<?x?xf32>
  return
}

// Non-contiguous, strided, masked store.
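// Lanes whose mask bit is 0 are not stored, so the existing values at those
// positions are left untouched (verified by the read in step 11 of @entry).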
func.func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %fn1 = arith.constant -2.0 : f32
  %vf0 = vector.splat %fn1 : vector<7xf32>
  %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
  vector.transfer_write %vf0, %A[%base1, %base2], %mask
    {permutation_map = affine_map<(d0, d1) -> (d0)>}
    : vector<7xf32>, memref<?x?xf32>
  return
}

func.func @entry() {
  %c0 = arith.constant 0: index
  %c1 = arith.constant 1: index
  %c2 = arith.constant 2: index
  %c3 = arith.constant 3: index
  %c10 = arith.constant 10 : index
  %0 = memref.get_global @gv : memref<5x6xf32>
  %A = memref.cast %0 : memref<5x6xf32> to memref<?x?xf32>

  // 1. Read from 2D memref on first dimension. Cannot be lowered to an LLVM
  //    vector load. Instead, generates scalar loads.
  call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )

  // 2.a. Read 1D vector from 2D memref with non-unit stride on first dim.
  call @transfer_read_1d_unit_stride(%A) : (memref<?x?xf32>) -> ()
  // CHECK: ( 10, 11 )
  // CHECK: ( 13, 14 )
  // CHECK: ( 30, 31 )
  // CHECK: ( 33, 34 )

  // 2.b. Read 1D vector from 2D memref with non-unit stride on first dim.
  //      Strides are non-static.
  call @transfer_read_1d_non_static_unit_stride(%A) : (memref<?x?xf32>) -> ()
  // CHECK: ( 31, 32, 33, 34 )

  // 2.c. Read 1D vector from 2D memref with out-of-bounds transfer dim starting
  //      point.
  call @transfer_read_1d_out_of_bounds(%A, %c10, %c1)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( -42, -42, -42 )

  // 3. Read 1D vector from 2D memref with non-unit stride on second dim.
  call @transfer_read_1d_non_unit_stride(%A) : (memref<?x?xf32>) -> ()
  // CHECK: ( 22, 24, -42 )

  // 4. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
  //    vector store. Instead, generates scalar stores.
  call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
  // 5. (Same as 1, to check that 4 worked correctly.)
  call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )

  // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
  //    Generates a loop with vector.insertelement.
  call @transfer_read_1d_broadcast(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )

  // 7. Read from 2D memref on first dimension. Accesses are in-bounds, so no
  //    if-check is generated inside the generated loop.
  call @transfer_read_1d_in_bounds(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, 22, -1 )

  // 8. Optional mask attribute is specified and, in addition, there may be
  //    out-of-bounds accesses.
  call @transfer_read_1d_mask(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )

  // 9. Same as 8, but accesses are in-bounds.
  call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, -42, -1 )

  // 10. Write to 2D memref on first dimension with a mask.
  call @transfer_write_1d_mask(%A, %c1, %c0)
      : (memref<?x?xf32>, index, index) -> ()

  // 11. (Same as 1, to check that 10 worked correctly.)
  call @transfer_read_1d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 )

  return
}