1 #include <asm/hwcap.h>
2 #include <stdbool.h>
3 #include <stdint.h>
4 #include <stdlib.h>
5 #include <string.h>
6 #include <sys/auxv.h>
7 #include <sys/prctl.h>
8
9 // Important details for this program:
10 // * Making a syscall will disable streaming mode if it is active.
11 // * Changing the vector length will make streaming mode and ZA inactive.
12 // * ZA can be active independent of streaming mode.
// * ZA is a square array: SVL rows of SVL bytes each, where SVL is the
//   streaming vector length in bytes.
14
// Fallback definitions for building against headers that do not yet provide
// the SME prctl interface.
#ifndef PR_SME_SET_VL
#define PR_SME_SET_VL 63
#endif

#ifndef PR_SME_GET_VL
#define PR_SME_GET_VL 64
#endif

// Mask selecting the vector length field of a PR_SME_* prctl value.
#ifndef PR_SME_VL_LEN_MASK
#define PR_SME_VL_LEN_MASK 0xffff
#endif

// SMSTART/SMSTOP written as raw "msr s0_3_c4_cN_3, xzr" system-register
// forms (presumably so no SME-aware assembler mnemonics are needed -- TODO
// confirm). The CRm digit selects the variant, as the names below show:
// 7/6 = start/stop both SM and ZA, 3/2 = streaming mode only, 5/4 = ZA only.
#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr")
#define SMSTART SM_INST(7)
#define SMSTART_SM SM_INST(3)
#define SMSTART_ZA SM_INST(5)
#define SMSTOP SM_INST(6)
#define SMSTOP_SM SM_INST(2)
#define SMSTOP_ZA SM_INST(4)

#ifndef HWCAP2_SME2
#define HWCAP2_SME2 (1UL << 37)
#endif

// Vector lengths taken from the command line in main and reused by the
// expr_* helpers below.
int start_vl = 0;
int other_vl = 0;
// Set in main when AT_HWCAP2 reports SME2; gates the ZT0 write in
// set_sme_registers.
bool has_zt0 = false;
42
write_sve_regs()43 void write_sve_regs() {
44 // We assume the smefa64 feature is present, which allows ffr access
45 // in streaming mode.
46 asm volatile("setffr\n\t");
47 asm volatile("ptrue p0.b\n\t");
48 asm volatile("ptrue p1.h\n\t");
49 asm volatile("ptrue p2.s\n\t");
50 asm volatile("ptrue p3.d\n\t");
51 asm volatile("pfalse p4.b\n\t");
52 asm volatile("ptrue p5.b\n\t");
53 asm volatile("ptrue p6.h\n\t");
54 asm volatile("ptrue p7.s\n\t");
55 asm volatile("ptrue p8.d\n\t");
56 asm volatile("pfalse p9.b\n\t");
57 asm volatile("ptrue p10.b\n\t");
58 asm volatile("ptrue p11.h\n\t");
59 asm volatile("ptrue p12.s\n\t");
60 asm volatile("ptrue p13.d\n\t");
61 asm volatile("pfalse p14.b\n\t");
62 asm volatile("ptrue p15.b\n\t");
63
64 asm volatile("cpy z0.b, p0/z, #1\n\t");
65 asm volatile("cpy z1.b, p5/z, #2\n\t");
66 asm volatile("cpy z2.b, p10/z, #3\n\t");
67 asm volatile("cpy z3.b, p15/z, #4\n\t");
68 asm volatile("cpy z4.b, p0/z, #5\n\t");
69 asm volatile("cpy z5.b, p5/z, #6\n\t");
70 asm volatile("cpy z6.b, p10/z, #7\n\t");
71 asm volatile("cpy z7.b, p15/z, #8\n\t");
72 asm volatile("cpy z8.b, p0/z, #9\n\t");
73 asm volatile("cpy z9.b, p5/z, #10\n\t");
74 asm volatile("cpy z10.b, p10/z, #11\n\t");
75 asm volatile("cpy z11.b, p15/z, #12\n\t");
76 asm volatile("cpy z12.b, p0/z, #13\n\t");
77 asm volatile("cpy z13.b, p5/z, #14\n\t");
78 asm volatile("cpy z14.b, p10/z, #15\n\t");
79 asm volatile("cpy z15.b, p15/z, #16\n\t");
80 asm volatile("cpy z16.b, p0/z, #17\n\t");
81 asm volatile("cpy z17.b, p5/z, #18\n\t");
82 asm volatile("cpy z18.b, p10/z, #19\n\t");
83 asm volatile("cpy z19.b, p15/z, #20\n\t");
84 asm volatile("cpy z20.b, p0/z, #21\n\t");
85 asm volatile("cpy z21.b, p5/z, #22\n\t");
86 asm volatile("cpy z22.b, p10/z, #23\n\t");
87 asm volatile("cpy z23.b, p15/z, #24\n\t");
88 asm volatile("cpy z24.b, p0/z, #25\n\t");
89 asm volatile("cpy z25.b, p5/z, #26\n\t");
90 asm volatile("cpy z26.b, p10/z, #27\n\t");
91 asm volatile("cpy z27.b, p15/z, #28\n\t");
92 asm volatile("cpy z28.b, p0/z, #29\n\t");
93 asm volatile("cpy z29.b, p5/z, #30\n\t");
94 asm volatile("cpy z30.b, p10/z, #31\n\t");
95 asm volatile("cpy z31.b, p15/z, #32\n\t");
96 }
97
98 // Write something different so we will know if we didn't restore them
99 // correctly.
write_sve_regs_expr()100 void write_sve_regs_expr() {
101 asm volatile("pfalse p0.b\n\t");
102 asm volatile("wrffr p0.b\n\t");
103 asm volatile("pfalse p1.b\n\t");
104 asm volatile("pfalse p2.b\n\t");
105 asm volatile("pfalse p3.b\n\t");
106 asm volatile("ptrue p4.b\n\t");
107 asm volatile("pfalse p5.b\n\t");
108 asm volatile("pfalse p6.b\n\t");
109 asm volatile("pfalse p7.b\n\t");
110 asm volatile("pfalse p8.b\n\t");
111 asm volatile("ptrue p9.b\n\t");
112 asm volatile("pfalse p10.b\n\t");
113 asm volatile("pfalse p11.b\n\t");
114 asm volatile("pfalse p12.b\n\t");
115 asm volatile("pfalse p13.b\n\t");
116 asm volatile("ptrue p14.b\n\t");
117 asm volatile("pfalse p15.b\n\t");
118
119 asm volatile("cpy z0.b, p0/z, #2\n\t");
120 asm volatile("cpy z1.b, p5/z, #3\n\t");
121 asm volatile("cpy z2.b, p10/z, #4\n\t");
122 asm volatile("cpy z3.b, p15/z, #5\n\t");
123 asm volatile("cpy z4.b, p0/z, #6\n\t");
124 asm volatile("cpy z5.b, p5/z, #7\n\t");
125 asm volatile("cpy z6.b, p10/z, #8\n\t");
126 asm volatile("cpy z7.b, p15/z, #9\n\t");
127 asm volatile("cpy z8.b, p0/z, #10\n\t");
128 asm volatile("cpy z9.b, p5/z, #11\n\t");
129 asm volatile("cpy z10.b, p10/z, #12\n\t");
130 asm volatile("cpy z11.b, p15/z, #13\n\t");
131 asm volatile("cpy z12.b, p0/z, #14\n\t");
132 asm volatile("cpy z13.b, p5/z, #15\n\t");
133 asm volatile("cpy z14.b, p10/z, #16\n\t");
134 asm volatile("cpy z15.b, p15/z, #17\n\t");
135 asm volatile("cpy z16.b, p0/z, #18\n\t");
136 asm volatile("cpy z17.b, p5/z, #19\n\t");
137 asm volatile("cpy z18.b, p10/z, #20\n\t");
138 asm volatile("cpy z19.b, p15/z, #21\n\t");
139 asm volatile("cpy z20.b, p0/z, #22\n\t");
140 asm volatile("cpy z21.b, p5/z, #23\n\t");
141 asm volatile("cpy z22.b, p10/z, #24\n\t");
142 asm volatile("cpy z23.b, p15/z, #25\n\t");
143 asm volatile("cpy z24.b, p0/z, #26\n\t");
144 asm volatile("cpy z25.b, p5/z, #27\n\t");
145 asm volatile("cpy z26.b, p10/z, #28\n\t");
146 asm volatile("cpy z27.b, p15/z, #29\n\t");
147 asm volatile("cpy z28.b, p0/z, #30\n\t");
148 asm volatile("cpy z29.b, p5/z, #31\n\t");
149 asm volatile("cpy z30.b, p10/z, #32\n\t");
150 asm volatile("cpy z31.b, p15/z, #33\n\t");
151 }
152
set_sme_registers(int svl,uint8_t value_offset)153 void set_sme_registers(int svl, uint8_t value_offset) {
154 #define MAX_VL_BYTES 256
155 uint8_t data[MAX_VL_BYTES];
156
157 // ldr za will actually wrap the selected vector row, by the number of rows
158 // you have. So setting one that didn't exist would actually set one that did.
159 // That's why we need the streaming vector length here.
160 for (int i = 0; i < svl; ++i) {
161 memset(data, i + value_offset, MAX_VL_BYTES);
162 // Each one of these loads a VL sized row of ZA.
163 asm volatile("mov w12, %w0\n\t"
164 "ldr za[w12, 0], [%1]\n\t" ::"r"(i),
165 "r"(&data)
166 : "w12");
167 }
168 #undef MAX_VL_BYTES
169
170 if (has_zt0) {
171 #define ZTO_LEN (512 / 8)
172 uint8_t data[ZTO_LEN];
173 for (unsigned i = 0; i < ZTO_LEN; ++i)
174 data[i] = i + value_offset;
175
176 asm volatile("ldr zt0, [%0]" ::"r"(&data));
177 #undef ZT0_LEN
178 }
179 }
180
expr_disable_za()181 void expr_disable_za() {
182 SMSTOP_ZA;
183 write_sve_regs_expr();
184 }
185
expr_enable_za()186 void expr_enable_za() {
187 SMSTART_ZA;
188 set_sme_registers(start_vl, 2);
189 write_sve_regs_expr();
190 }
191
expr_start_vl()192 void expr_start_vl() {
193 prctl(PR_SME_SET_VL, start_vl);
194 SMSTART_ZA;
195 set_sme_registers(start_vl, 4);
196 write_sve_regs_expr();
197 }
198
expr_other_vl()199 void expr_other_vl() {
200 prctl(PR_SME_SET_VL, other_vl);
201 SMSTART_ZA;
202 set_sme_registers(other_vl, 5);
203 write_sve_regs_expr();
204 }
205
expr_enable_sm()206 void expr_enable_sm() {
207 SMSTART_SM;
208 write_sve_regs_expr();
209 }
210
expr_disable_sm()211 void expr_disable_sm() {
212 SMSTOP_SM;
213 write_sve_regs_expr();
214 }
215
main(int argc,char * argv[])216 int main(int argc, char *argv[]) {
217 // We expect to get:
218 // * whether to enable streaming mode
219 // * whether to enable ZA
220 // * what the starting VL should be
221 // * what the other VL should be
222 if (argc != 5)
223 return 1;
224
225 bool ssve = argv[1][0] == '1';
226 bool za = argv[2][0] == '1';
227 start_vl = atoi(argv[3]);
228 other_vl = atoi(argv[4]);
229
230 if ((getauxval(AT_HWCAP2) & HWCAP2_SME2))
231 has_zt0 = true;
232
233 prctl(PR_SME_SET_VL, start_vl);
234
235 if (ssve)
236 SMSTART_SM;
237
238 if (za) {
239 SMSTART_ZA;
240 set_sme_registers(start_vl, 1);
241 }
242
243 write_sve_regs();
244
245 return 0; // Set a break point here.
246 }
247