// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s

/// This tests the basic functionality of the -arm-sve-legalize-vector-storage pass.

// -----

/// Store and reload a vector<[1]xi1> predicate. The pass widens the backing
/// alloca to the in-memory svbool form (vector<[16]xi1>) and inserts
/// convert_to/from_svbool around the store and load.
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[1]xi1>)
func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[1]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[1]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1>
  %reload = memref.load %alloca[] : memref<vector<[1]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1>
  return %reload : vector<[1]xi1>
}

// -----

/// Store and reload a vector<[2]xi1> predicate. Same legalization as the
/// nxv1i1 case: widened svbool storage plus conversions.
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[2]xi1>)
func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[2]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[2]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1>
  %reload = memref.load %alloca[] : memref<vector<[2]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1>
  return %reload : vector<[2]xi1>
}

// -----

/// Store and reload a vector<[4]xi1> predicate. Same legalization as the
/// nxv1i1 case: widened svbool storage plus conversions.
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[4]xi1>)
func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[4]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1>
  %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1>
  return %reload : vector<[4]xi1>
}

// -----

/// Store and reload a vector<[8]xi1> predicate. Same legalization as the
/// nxv1i1 case: widened svbool storage plus conversions.
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[8]xi1>)
func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[8]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
  %reload = memref.load %alloca[] : memref<vector<[8]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1>
  return %reload : vector<[8]xi1>
}

// -----

/// vector<[16]xi1> is already the in-memory svbool size, so only the alloca
/// alignment is set — no conversions are inserted.
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv16i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[16]xi1>)
func.func @store_and_reload_sve_predicate_nxv16i1(%mask: vector<[16]xi1>) -> vector<[16]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[16]xi1>>
  // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  %reload = memref.load %alloca[] : memref<vector<[16]xi1>>
  // CHECK-NEXT: return %[[RELOAD]] : vector<[16]xi1>
  return %reload : vector<[16]xi1>
}

// -----

/// This is not a valid SVE mask type, so is ignored by the
/// `-arm-sve-legalize-vector-storage` pass.

/// vector<[7]xi1> is not a legal SVE predicate shape; the pass must leave the
/// IR untouched (no widening, no svbool conversions).
// CHECK-LABEL: @store_and_reload_unsupported_type(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[7]xi1>)
func.func @store_and_reload_unsupported_type(%mask: vector<[7]xi1>) -> vector<[7]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[7]xi1>>
  %alloca = memref.alloca() : memref<vector<[7]xi1>>
  // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref<vector<[7]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[7]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[7]xi1>>
  %reload = memref.load %alloca[] : memref<vector<[7]xi1>>
  // CHECK-NEXT: return %[[RELOAD]] : vector<[7]xi1>
  return %reload : vector<[7]xi1>
}

// -----

/// A rank-2 mask is widened to vector<3x[16]xi1>; a type_cast + load of one
/// row yields a legal svbool slice that is converted back on reload.
// CHECK-LABEL: @store_2d_mask_and_reload_slice(
// CHECK-SAME:                                  %[[MASK:.*]]: vector<3x[8]xi1>)
func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<3x[16]xi1>>
  %alloca = memref.alloca() : memref<vector<3x[8]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<3x[8]xi1>>
  // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
  %unpack = vector.type_cast %alloca : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>>
  // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
  %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>>
  // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1>
  return %slice : vector<[8]xi1>
}

// -----

// CHECK-LABEL: @set_sve_alloca_alignment
func.func @set_sve_alloca_alignment() {
  /// This checks the alignment of alloca's of scalable vectors will be
  /// something the backend can handle. Currently, the backend sets the
  /// alignment of scalable vectors to their base size (i.e. their size at
  /// vscale = 1). This works for hardware-sized types, which always get a
  /// 16-byte alignment. The problem is larger types e.g. vector<[8]xf32> end up
  /// with alignments larger than 16-bytes (e.g. 32-bytes here), which are
  /// unsupported. The `-arm-sve-legalize-vector-storage` pass avoids this
  /// issue by explicitly setting the alignment to 16-bytes for all scalable
  /// vectors.

  // CHECK-COUNT-6: alignment = 16
  %a1 = memref.alloca() : memref<vector<[32]xi8>>
  %a2 = memref.alloca() : memref<vector<[16]xi8>>
  %a3 = memref.alloca() : memref<vector<[8]xi8>>
  %a4 = memref.alloca() : memref<vector<[4]xi8>>
  %a5 = memref.alloca() : memref<vector<[2]xi8>>
  %a6 = memref.alloca() : memref<vector<[1]xi8>>

  // CHECK-COUNT-6: alignment = 16
  %b1 = memref.alloca() : memref<vector<[32]xi16>>
  %b2 = memref.alloca() : memref<vector<[16]xi16>>
  %b3 = memref.alloca() : memref<vector<[8]xi16>>
  %b4 = memref.alloca() : memref<vector<[4]xi16>>
  %b5 = memref.alloca() : memref<vector<[2]xi16>>
  %b6 = memref.alloca() : memref<vector<[1]xi16>>

  // CHECK-COUNT-6: alignment = 16
  %c1 = memref.alloca() : memref<vector<[32]xi32>>
  %c2 = memref.alloca() : memref<vector<[16]xi32>>
  %c3 = memref.alloca() : memref<vector<[8]xi32>>
  %c4 = memref.alloca() : memref<vector<[4]xi32>>
  %c5 = memref.alloca() : memref<vector<[2]xi32>>
  %c6 = memref.alloca() : memref<vector<[1]xi32>>

  // CHECK-COUNT-6: alignment = 16
  %d1 = memref.alloca() : memref<vector<[32]xi64>>
  %d2 = memref.alloca() : memref<vector<[16]xi64>>
  %d3 = memref.alloca() : memref<vector<[8]xi64>>
  %d4 = memref.alloca() : memref<vector<[4]xi64>>
  %d5 = memref.alloca() : memref<vector<[2]xi64>>
  %d6 = memref.alloca() : memref<vector<[1]xi64>>

  // CHECK-COUNT-6: alignment = 16
  %e1 = memref.alloca() : memref<vector<[32]xf32>>
  %e2 = memref.alloca() : memref<vector<[16]xf32>>
  %e3 = memref.alloca() : memref<vector<[8]xf32>>
  %e4 = memref.alloca() : memref<vector<[4]xf32>>
  %e5 = memref.alloca() : memref<vector<[2]xf32>>
  %e6 = memref.alloca() : memref<vector<[1]xf32>>

  // CHECK-COUNT-6: alignment = 16
  %f1 = memref.alloca() : memref<vector<[32]xf64>>
  %f2 = memref.alloca() : memref<vector<[16]xf64>>
  %f3 = memref.alloca() : memref<vector<[8]xf64>>
  %f4 = memref.alloca() : memref<vector<[4]xf64>>
  %f5 = memref.alloca() : memref<vector<[2]xf64>>
  %f6 = memref.alloca() : memref<vector<[1]xf64>>

  /// Keep all allocas live so they are not folded away before the check.
  "prevent.dce"(
    %a1, %a2, %a3, %a4, %a5, %a6,
    %b1, %b2, %b3, %b4, %b5, %b6,
    %c1, %c2, %c3, %c4, %c5, %c6,
    %d1, %d2, %d3, %d4, %d5, %d6,
    %e1, %e2, %e3, %e4, %e5, %e6,
    %f1, %f2, %f3, %f4, %f5, %f6)
    : (memref<vector<[32]xi8>>, memref<vector<[16]xi8>>, memref<vector<[8]xi8>>, memref<vector<[4]xi8>>, memref<vector<[2]xi8>>, memref<vector<[1]xi8>>,
       memref<vector<[32]xi16>>, memref<vector<[16]xi16>>, memref<vector<[8]xi16>>, memref<vector<[4]xi16>>, memref<vector<[2]xi16>>, memref<vector<[1]xi16>>,
       memref<vector<[32]xi32>>, memref<vector<[16]xi32>>, memref<vector<[8]xi32>>, memref<vector<[4]xi32>>, memref<vector<[2]xi32>>, memref<vector<[1]xi32>>,
       memref<vector<[32]xi64>>, memref<vector<[16]xi64>>, memref<vector<[8]xi64>>, memref<vector<[4]xi64>>, memref<vector<[2]xi64>>, memref<vector<[1]xi64>>,
       memref<vector<[32]xf32>>, memref<vector<[16]xf32>>, memref<vector<[8]xf32>>, memref<vector<[4]xf32>>, memref<vector<[2]xf32>>, memref<vector<[1]xf32>>,
       memref<vector<[32]xf64>>, memref<vector<[16]xf64>>, memref<vector<[8]xf64>>, memref<vector<[4]xf64>>, memref<vector<[2]xf64>>, memref<vector<[1]xf64>>) -> ()
  return
}