xref: /llvm-project/lldb/test/API/commands/register/register/aarch64_sme_z_registers/save_restore/main.c (revision b8150c8f12fcb3c3c5e40611ddd883db1506be35)
1 #include <asm/hwcap.h>
2 #include <stdbool.h>
3 #include <stdint.h>
4 #include <stdlib.h>
5 #include <string.h>
6 #include <sys/auxv.h>
7 #include <sys/prctl.h>
8 
9 // Important details for this program:
10 // * Making a syscall will disable streaming mode if it is active.
11 // * Changing the vector length will make streaming mode and ZA inactive.
12 // * ZA can be active independent of streaming mode.
13 // * ZA's size is the streaming vector length squared.
14 
// Fallback values for the SME prctl() interface, used only when the system
// headers do not already provide them.
#ifndef PR_SME_SET_VL
#define PR_SME_SET_VL 63
#endif

#ifndef PR_SME_GET_VL
#define PR_SME_GET_VL 64
#endif

#ifndef PR_SME_VL_LEN_MASK
#define PR_SME_VL_LEN_MASK 0xffff
#endif

// SM_INST(c) emits "msr s0_3_c4_c<c>_3, xzr", i.e. the instruction is spelled
// as a raw system register write rather than by mnemonic.
// NOTE(review): presumably so the file builds with assemblers that do not know
// the smstart/smstop mnemonics - confirm.
#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr")

// Start/stop helpers. The _SM variants act on streaming mode only, the _ZA
// variants on ZA only, and the unsuffixed variants on both.
#define SMSTART SM_INST(7)
#define SMSTART_SM SM_INST(3)
#define SMSTART_ZA SM_INST(5)
#define SMSTOP SM_INST(6)
#define SMSTOP_SM SM_INST(2)
#define SMSTOP_ZA SM_INST(4)

// Fallback for the AT_HWCAP2 bit that main() tests to decide whether ZT0 is
// available.
#ifndef HWCAP2_SME2
#define HWCAP2_SME2 (1UL << 37)
#endif

// Vector lengths taken from the command line by main(); the expr_*_vl()
// helpers switch between them.
int start_vl = 0;
int other_vl = 0;
// Set by main() when HWCAP2_SME2 is reported; gates the ZT0 write in
// set_sme_registers().
bool has_zt0 = false;

write_sve_regs()43 void write_sve_regs() {
44   // We assume the smefa64 feature is present, which allows ffr access
45   // in streaming mode.
46   asm volatile("setffr\n\t");
47   asm volatile("ptrue p0.b\n\t");
48   asm volatile("ptrue p1.h\n\t");
49   asm volatile("ptrue p2.s\n\t");
50   asm volatile("ptrue p3.d\n\t");
51   asm volatile("pfalse p4.b\n\t");
52   asm volatile("ptrue p5.b\n\t");
53   asm volatile("ptrue p6.h\n\t");
54   asm volatile("ptrue p7.s\n\t");
55   asm volatile("ptrue p8.d\n\t");
56   asm volatile("pfalse p9.b\n\t");
57   asm volatile("ptrue p10.b\n\t");
58   asm volatile("ptrue p11.h\n\t");
59   asm volatile("ptrue p12.s\n\t");
60   asm volatile("ptrue p13.d\n\t");
61   asm volatile("pfalse p14.b\n\t");
62   asm volatile("ptrue p15.b\n\t");
63 
64   asm volatile("cpy  z0.b, p0/z, #1\n\t");
65   asm volatile("cpy  z1.b, p5/z, #2\n\t");
66   asm volatile("cpy  z2.b, p10/z, #3\n\t");
67   asm volatile("cpy  z3.b, p15/z, #4\n\t");
68   asm volatile("cpy  z4.b, p0/z, #5\n\t");
69   asm volatile("cpy  z5.b, p5/z, #6\n\t");
70   asm volatile("cpy  z6.b, p10/z, #7\n\t");
71   asm volatile("cpy  z7.b, p15/z, #8\n\t");
72   asm volatile("cpy  z8.b, p0/z, #9\n\t");
73   asm volatile("cpy  z9.b, p5/z, #10\n\t");
74   asm volatile("cpy  z10.b, p10/z, #11\n\t");
75   asm volatile("cpy  z11.b, p15/z, #12\n\t");
76   asm volatile("cpy  z12.b, p0/z, #13\n\t");
77   asm volatile("cpy  z13.b, p5/z, #14\n\t");
78   asm volatile("cpy  z14.b, p10/z, #15\n\t");
79   asm volatile("cpy  z15.b, p15/z, #16\n\t");
80   asm volatile("cpy  z16.b, p0/z, #17\n\t");
81   asm volatile("cpy  z17.b, p5/z, #18\n\t");
82   asm volatile("cpy  z18.b, p10/z, #19\n\t");
83   asm volatile("cpy  z19.b, p15/z, #20\n\t");
84   asm volatile("cpy  z20.b, p0/z, #21\n\t");
85   asm volatile("cpy  z21.b, p5/z, #22\n\t");
86   asm volatile("cpy  z22.b, p10/z, #23\n\t");
87   asm volatile("cpy  z23.b, p15/z, #24\n\t");
88   asm volatile("cpy  z24.b, p0/z, #25\n\t");
89   asm volatile("cpy  z25.b, p5/z, #26\n\t");
90   asm volatile("cpy  z26.b, p10/z, #27\n\t");
91   asm volatile("cpy  z27.b, p15/z, #28\n\t");
92   asm volatile("cpy  z28.b, p0/z, #29\n\t");
93   asm volatile("cpy  z29.b, p5/z, #30\n\t");
94   asm volatile("cpy  z30.b, p10/z, #31\n\t");
95   asm volatile("cpy  z31.b, p15/z, #32\n\t");
96 }
97 
98 // Write something different so we will know if we didn't restore them
99 // correctly.
write_sve_regs_expr()100 void write_sve_regs_expr() {
101   asm volatile("pfalse p0.b\n\t");
102   asm volatile("wrffr p0.b\n\t");
103   asm volatile("pfalse p1.b\n\t");
104   asm volatile("pfalse p2.b\n\t");
105   asm volatile("pfalse p3.b\n\t");
106   asm volatile("ptrue p4.b\n\t");
107   asm volatile("pfalse p5.b\n\t");
108   asm volatile("pfalse p6.b\n\t");
109   asm volatile("pfalse p7.b\n\t");
110   asm volatile("pfalse p8.b\n\t");
111   asm volatile("ptrue p9.b\n\t");
112   asm volatile("pfalse p10.b\n\t");
113   asm volatile("pfalse p11.b\n\t");
114   asm volatile("pfalse p12.b\n\t");
115   asm volatile("pfalse p13.b\n\t");
116   asm volatile("ptrue p14.b\n\t");
117   asm volatile("pfalse p15.b\n\t");
118 
119   asm volatile("cpy  z0.b, p0/z, #2\n\t");
120   asm volatile("cpy  z1.b, p5/z, #3\n\t");
121   asm volatile("cpy  z2.b, p10/z, #4\n\t");
122   asm volatile("cpy  z3.b, p15/z, #5\n\t");
123   asm volatile("cpy  z4.b, p0/z, #6\n\t");
124   asm volatile("cpy  z5.b, p5/z, #7\n\t");
125   asm volatile("cpy  z6.b, p10/z, #8\n\t");
126   asm volatile("cpy  z7.b, p15/z, #9\n\t");
127   asm volatile("cpy  z8.b, p0/z, #10\n\t");
128   asm volatile("cpy  z9.b, p5/z, #11\n\t");
129   asm volatile("cpy  z10.b, p10/z, #12\n\t");
130   asm volatile("cpy  z11.b, p15/z, #13\n\t");
131   asm volatile("cpy  z12.b, p0/z, #14\n\t");
132   asm volatile("cpy  z13.b, p5/z, #15\n\t");
133   asm volatile("cpy  z14.b, p10/z, #16\n\t");
134   asm volatile("cpy  z15.b, p15/z, #17\n\t");
135   asm volatile("cpy  z16.b, p0/z, #18\n\t");
136   asm volatile("cpy  z17.b, p5/z, #19\n\t");
137   asm volatile("cpy  z18.b, p10/z, #20\n\t");
138   asm volatile("cpy  z19.b, p15/z, #21\n\t");
139   asm volatile("cpy  z20.b, p0/z, #22\n\t");
140   asm volatile("cpy  z21.b, p5/z, #23\n\t");
141   asm volatile("cpy  z22.b, p10/z, #24\n\t");
142   asm volatile("cpy  z23.b, p15/z, #25\n\t");
143   asm volatile("cpy  z24.b, p0/z, #26\n\t");
144   asm volatile("cpy  z25.b, p5/z, #27\n\t");
145   asm volatile("cpy  z26.b, p10/z, #28\n\t");
146   asm volatile("cpy  z27.b, p15/z, #29\n\t");
147   asm volatile("cpy  z28.b, p0/z, #30\n\t");
148   asm volatile("cpy  z29.b, p5/z, #31\n\t");
149   asm volatile("cpy  z30.b, p10/z, #32\n\t");
150   asm volatile("cpy  z31.b, p15/z, #33\n\t");
151 }
152 
set_sme_registers(int svl,uint8_t value_offset)153 void set_sme_registers(int svl, uint8_t value_offset) {
154 #define MAX_VL_BYTES 256
155   uint8_t data[MAX_VL_BYTES];
156 
157   // ldr za will actually wrap the selected vector row, by the number of rows
158   // you have. So setting one that didn't exist would actually set one that did.
159   // That's why we need the streaming vector length here.
160   for (int i = 0; i < svl; ++i) {
161     memset(data, i + value_offset, MAX_VL_BYTES);
162     // Each one of these loads a VL sized row of ZA.
163     asm volatile("mov w12, %w0\n\t"
164                  "ldr za[w12, 0], [%1]\n\t" ::"r"(i),
165                  "r"(&data)
166                  : "w12");
167   }
168 #undef MAX_VL_BYTES
169 
170   if (has_zt0) {
171 #define ZTO_LEN (512 / 8)
172     uint8_t data[ZTO_LEN];
173     for (unsigned i = 0; i < ZTO_LEN; ++i)
174       data[i] = i + value_offset;
175 
176     asm volatile("ldr zt0, [%0]" ::"r"(&data));
177 #undef ZT0_LEN
178   }
179 }
180 
expr_disable_za()181 void expr_disable_za() {
182   SMSTOP_ZA;
183   write_sve_regs_expr();
184 }
185 
expr_enable_za()186 void expr_enable_za() {
187   SMSTART_ZA;
188   set_sme_registers(start_vl, 2);
189   write_sve_regs_expr();
190 }
191 
expr_start_vl()192 void expr_start_vl() {
193   prctl(PR_SME_SET_VL, start_vl);
194   SMSTART_ZA;
195   set_sme_registers(start_vl, 4);
196   write_sve_regs_expr();
197 }
198 
expr_other_vl()199 void expr_other_vl() {
200   prctl(PR_SME_SET_VL, other_vl);
201   SMSTART_ZA;
202   set_sme_registers(other_vl, 5);
203   write_sve_regs_expr();
204 }
205 
expr_enable_sm()206 void expr_enable_sm() {
207   SMSTART_SM;
208   write_sve_regs_expr();
209 }
210 
expr_disable_sm()211 void expr_disable_sm() {
212   SMSTOP_SM;
213   write_sve_regs_expr();
214 }
215 
main(int argc,char * argv[])216 int main(int argc, char *argv[]) {
217   // We expect to get:
218   // * whether to enable streaming mode
219   // * whether to enable ZA
220   // * what the starting VL should be
221   // * what the other VL should be
222   if (argc != 5)
223     return 1;
224 
225   bool ssve = argv[1][0] == '1';
226   bool za = argv[2][0] == '1';
227   start_vl = atoi(argv[3]);
228   other_vl = atoi(argv[4]);
229 
230   if ((getauxval(AT_HWCAP2) & HWCAP2_SME2))
231     has_zt0 = true;
232 
233   prctl(PR_SME_SET_VL, start_vl);
234 
235   if (ssve)
236     SMSTART_SM;
237 
238   if (za) {
239     SMSTART_ZA;
240     set_sme_registers(start_vl, 1);
241   }
242 
243   write_sve_regs();
244 
245   return 0; // Set a break point here.
246 }
247