xref: /llvm-project/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
1// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf,lower-affine,convert-scf-to-cf),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" | \
2// RUN: mlir-runner -e entry -entry-point-result=void  \
3// RUN:   -shared-libs=%mlir_c_runner_utils | \
4// RUN: FileCheck %s
5
6// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true},lower-affine,convert-scf-to-cf),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)" | \
7// RUN: mlir-runner -e entry -entry-point-result=void  \
8// RUN:   -shared-libs=%mlir_c_runner_utils | \
9// RUN: FileCheck %s
10
// Shared 3x4 input buffer; element [i][j] holds 10 * i + j, so every printed
// value identifies the position it was read from. It is not marked `constant`
// because @entry stores into it via transfer_write.
memref.global "private" @gv : memref<3x4xf32> = dense<[
  [ 0.0,  1.0,  2.0,  3.0],
  [10.0, 11.0, 12.0, 13.0],
  [20.0, 21.0, 22.0, 23.0]
]>
14
// Unmasked 2-D read: loads a 4x9 tile of %A starting at [%base1, %base2]
// (identity permutation) and prints it. No in_bounds attribute is given, so
// lanes that fall outside %A receive the padding value -42.
func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  %padding = arith.constant -42.0 : f32
  %tile = vector.transfer_read %A[%base1, %base2], %padding
      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : memref<?x?xf32>, vector<4x9xf32>
  vector.print %tile : vector<4x9xf32>
  return
}
24
// Masked 2-D read: loads a 4x9 tile of %A at [%base1, %base2] (identity
// permutation) under a constant 4x9 mask and prints it. A lane holds -42 when
// its mask bit is 0 or when it lies outside %A.
func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  %padding = arith.constant -42.0 : f32
  %lanes = arith.constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
                                 [0, 0, 1, 1, 1, 1, 1, 0, 1],
                                 [1, 1, 1, 1, 1, 1, 1, 0, 1],
                                 [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
  %tile = vector.transfer_read %A[%base1, %base2], %padding, %lanes
      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : memref<?x?xf32>, vector<4x9xf32>
  vector.print %tile : vector<4x9xf32>
  return
}
38
// Masked read with a transposing permutation map (d1, d0). The result is 9x4,
// and element [i][j] comes from %A[%base1 + j][%base2 + i]. The mask is 4x9:
// it follows the memref's dimension order, not the transposed result order.
// Masked-off or out-of-bounds lanes hold -42.
func.func @transfer_read_2d_mask_transposed(
    %A : memref<?x?xf32>, %base1: index, %base2: index) {
  %padding = arith.constant -42.0 : f32
  %lanes = arith.constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
                                 [0, 0, 1, 1, 1, 1, 1, 0, 1],
                                 [1, 1, 1, 1, 1, 1, 1, 0, 1],
                                 [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
  %tile = vector.transfer_read %A[%base1, %base2], %padding, %lanes
      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
      : memref<?x?xf32>, vector<9x4xf32>
  vector.print %tile : vector<9x4xf32>
  return
}
53
// Masked 1-D read broadcast to 2-D. Nine consecutive elements of row %base1,
// starting at column %base2, are read under a 9-lane mask, and that row is
// replicated into all 4 result rows. The `0` in the permutation map makes
// result dim 0 a broadcast dim. Masked-off or out-of-bounds lanes hold -42.
func.func @transfer_read_2d_mask_broadcast(
    %A : memref<?x?xf32>, %base1: index, %base2: index) {
  %padding = arith.constant -42.0 : f32
  %row_mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
  %rows = vector.transfer_read %A[%base1, %base2], %padding, %row_mask
      {permutation_map = affine_map<(d0, d1) -> (0, d1)>}
      : memref<?x?xf32>, vector<4x9xf32>
  vector.print %rows : vector<4x9xf32>
  return
}
65
// Masked 1-D read along the second memref dimension, broadcast along the last
// result dimension. The map (d1, 0) places memref dim 1 on result dim 0 (a
// transpose) and makes result dim 1 a broadcast dim. So result[i][j] is
// %A[%base1][%base2 + i] for every j when mask lane i is set and the element
// is in bounds; otherwise it is -42.
func.func @transfer_read_2d_mask_transpose_broadcast_last_dim(
    %A : memref<?x?xf32>, %base1: index, %base2: index) {
  %padding = arith.constant -42.0 : f32
  %col_mask = arith.constant dense<[1, 0, 1, 1]> : vector<4xi1>
  %cols = vector.transfer_read %A[%base1, %base2], %padding, %col_mask
      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>}
      : memref<?x?xf32>, vector<4x9xf32>
  vector.print %cols : vector<4x9xf32>
  return
}
77
// Unmasked read with a transposing permutation map (d1, d0). The result is
// 4x9, and element [i][j] comes from %A[%base1 + j][%base2 + i]. Lanes
// outside %A hold the padding value -42.
func.func @transfer_read_2d_transposed(
    %A : memref<?x?xf32>, %base1: index, %base2: index) {
  %padding = arith.constant -42.0 : f32
  %tile = vector.transfer_read %A[%base1, %base2], %padding
      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
      : memref<?x?xf32>, vector<4x9xf32>
  vector.print %tile : vector<4x9xf32>
  return
}
88
// Unmasked 1-D read broadcast to 2-D. The map (d1, 0) reads 4 elements of row
// %base1 starting at column %base2, places them along result dim 0, and
// replicates each across the 9 lanes of result dim 1. Out-of-bounds elements
// become -42.
func.func @transfer_read_2d_broadcast(
    %A : memref<?x?xf32>, %base1: index, %base2: index) {
  %padding = arith.constant -42.0 : f32
  %cols = vector.transfer_read %A[%base1, %base2], %padding
      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>}
      : memref<?x?xf32>, vector<4x9xf32>
  vector.print %cols : vector<4x9xf32>
  return
}
99
// Unmasked 2-D store of a 1x4 vector filled with -1.0 into %A at
// [%base1, %base2]. No in_bounds attribute is given, so the store may run past
// the edge of %A; those elements are skipped.
func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  %minus_one = arith.constant -1.0 : f32
  %fill = vector.broadcast %minus_one : f32 to vector<1x4xf32>
  vector.transfer_write %fill, %A[%base1, %base2]
      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : vector<1x4xf32>, memref<?x?xf32>
  return
}
109
// Masked 2-D store of a 1x4 vector filled with -2.0 into %A at
// [%base1, %base2]. The mask [[1, 0, 1, 0]] enables only lanes 0 and 2. Any
// enabled lane that would land outside %A is skipped as well, because no
// in_bounds attribute is given.
func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  // Named after its value, following the %fn1 (-1.0) / %fm42 (-42.0)
  // convention used elsewhere in this file.
  %fn2 = arith.constant -2.0 : f32
  %mask = arith.constant dense<[[1, 0, 1, 0]]> : vector<1x4xi1>
  %vf0 = vector.broadcast %fn2 : f32 to vector<1x4xf32>
  vector.transfer_write %vf0, %A[%base1, %base2], %mask
    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
    vector<1x4xf32>, memref<?x?xf32>
  return
}
120
// Test driver. Each helper prints one vector, and FileCheck matches the
// printed lines in order. The stores in steps 8 and 10 mutate @gv, so the
// reads that follow them observe the updated contents.
func.func @entry() {
  %i0 = arith.constant 0 : index
  %i1 = arith.constant 1 : index
  %i2 = arith.constant 2 : index
  %i3 = arith.constant 3 : index
  %i10 = arith.constant 10 : index
  %buf = memref.get_global @gv : memref<3x4xf32>
  %mem = memref.cast %buf : memref<3x4xf32> to memref<?x?xf32>

  // 1.a. Plain 2-D read at [1, 2]; the tile extends past the memref edge.
  func.call @transfer_read_2d(%mem, %i1, %i2) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  // 1.b. Plain 2-D read whose first-dimension start (row 3) is already out of
  //      bounds.
  func.call @transfer_read_2d(%mem, %i3, %i2) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  // 1.c. Plain 2-D read whose second-dimension start (column 10) is already
  //      out of bounds.
  func.call @transfer_read_2d(%mem, %i1, %i10) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  // 2. 2-D read at [1, 2] with the result transposed.
  func.call @transfer_read_2d_transposed(%mem, %i1, %i2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  // 3. 2-D read at [0, 0] under a 2-D mask; some enabled lanes are also out of
  //    bounds.
  func.call @transfer_read_2d_mask(%mem, %i0, %i0)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  // 4. Step 3 again, with the result transposed.
  func.call @transfer_read_2d_mask_transposed(%mem, %i0, %i0)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 0, -42, 20, -42 ), ( -42, -42, 21, -42 ), ( 2, 12, 22, -42 ), ( -42, 13, 23, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ) )

  // 5. 1-D read at [1, 2], broadcast to 2-D.
  func.call @transfer_read_2d_broadcast(%mem, %i1, %i2)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  // 6. Masked 1-D read at [2, 1], broadcast to 2-D.
  func.call @transfer_read_2d_mask_broadcast(%mem, %i2, %i1)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ) )

  // 7. Masked 1-D read at [0, 1] along the second memref dimension, broadcast
  //    to 2-D. Here the mask elements must be evaluated before the transfer is
  //    lowered to an (N>1)-D transfer.
  func.call @transfer_read_2d_mask_transpose_broadcast_last_dim(%mem, %i0, %i1)
      : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 1, 1, 1, 1, 1, 1, 1, 1, 1 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( 3, 3, 3, 3, 3, 3, 3, 3, 3 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  // 8. Unmasked 2-D store at [1, 2].
  func.call @transfer_write_2d(%mem, %i1, %i2) : (memref<?x?xf32>, index, index) -> ()

  // 9. Read back from [0, 0] to confirm step 8.
  func.call @transfer_read_2d(%mem, %i0, %i0) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  // 10. Masked 2-D store at [0, 2].
  func.call @transfer_write_2d_mask(%mem, %i0, %i2) : (memref<?x?xf32>, index, index) -> ()

  // 11. Read back from [0, 0] to confirm step 10.
  func.call @transfer_read_2d(%mem, %i0, %i0) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( ( 0, 1, -2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

  return
}
196
197