1"""
2Test AArch64 dynamic register sets
3"""
4
5import lldb
6from lldbsuite.test.decorators import *
7from lldbsuite.test.lldbtest import *
8from lldbsuite.test import lldbutil
9
10
class RegisterCommandsTestCase(TestBase):
    """Checks that LLDB only exposes the AArch64 register sets (SVE, SME,
    MTE, Pointer Authentication) that the target reports, and that the
    registers in them can be read and written."""

    @staticmethod
    def _format_bytes(values):
        """Return ``values`` (ints 0-255) in LLDB's vector register format,
        e.g. ``[1, 2]`` -> ``"{0x01 0x02}"``."""
        return "{" + " ".join("0x{:02x}".format(v) for v in values) + "}"

    def check_sve_register_size(self, set, name, expected):
        """Assert that register set ``set`` has a child register ``name``
        whose byte size is ``expected``."""
        # The parameter name ``set`` shadows the builtin; it is kept so the
        # method's keyword interface does not change.
        reg_value = set.GetChildMemberWithName(name)
        self.assertTrue(reg_value.IsValid(), "Expected a register named %s" % (name))
        self.assertEqual(
            reg_value.GetByteSize(),
            expected,
            "Expected a register %s size == %i bytes" % (name, expected),
        )

    def sve_regs_read_dynamic(self, sve_registers):
        """Check the Z, P and FFR registers of ``sve_registers``.

        First verifies the values the inferior placed in the registers before
        the breakpoint, then writes new values and reads them back. Register
        sizes are derived from the set's current ``vg`` value.
        """
        vg = sve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned()

        # Z registers are vg * 8 bytes. P registers hold one bit for each
        # byte of a Z register, so they are an eighth of that size.
        z_reg_size = vg * 8
        p_reg_size = z_reg_size // 8

        # Z<i> is expected to hold the byte i + 1 in every position.
        for i in range(32):
            z_regs_value = self._format_bytes([i + 1] * z_reg_size)
            self.expect("register read z%i" % (i), substrs=[z_regs_value])

        # P<i> is expected to hold the byte p_value_bytes[i % 5] in every
        # position. So P(0,5,10,15) have all Z register lanes set while
        # P(4,9,14) have no lanes set.
        p_value_bytes = [0xFF, 0x55, 0x11, 0x01, 0x00]
        for i in range(16):
            p_regs_value = self._format_bytes([p_value_bytes[i % 5]] * p_reg_size)
            self.expect("register read p%i" % (i), substrs=[p_regs_value])

        # FFR is expected to have all lanes set (the same pattern as P15).
        ffr_value = self._format_bytes([0xFF] * p_reg_size)
        self.expect("register read ffr", substrs=[ffr_value])

        # Now write new values to every register and check they read back.
        for i in range(32):
            z_regs_value = self._format_bytes([32 - i] * z_reg_size)
            self.runCmd("register write z%i '%s'" % (i, z_regs_value))
            self.expect("register read z%i" % (i), substrs=[z_regs_value])

        for i in range(16):
            p_regs_value = self._format_bytes([16 - i] * p_reg_size)
            self.runCmd("register write p%i '%s'" % (i, p_regs_value))
            self.expect("register read p%i" % (i), substrs=[p_regs_value])

        ffr_value = self._format_bytes([0x08] * p_reg_size)
        self.runCmd("register write ffr '%s'" % ffr_value)
        self.expect("register read ffr", substrs=[ffr_value])

    def setup_register_config_test(self, run_args=None):
        """Build main.c, run it to the "Set a break point here." line and
        return the register sets of the selected frame.

        ``run_args``, if given, is passed as the inferior's run arguments.
        """
        self.build()
        self.line = line_number("main.c", "// Set a break point here.")

        exe = self.getBuildArtifact("a.out")
        if run_args is not None:
            self.runCmd("settings set target.run-args " + run_args)
        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)

        lldbutil.run_break_set_by_file_and_line(
            self, "main.c", self.line, num_expected_locations=1
        )
        self.runCmd("run", RUN_SUCCEEDED)

        self.expect(
            "thread backtrace",
            STOPPED_DUE_TO_BREAKPOINT,
            substrs=["stop reason = breakpoint 1."],
        )

        return self.thread().GetSelectedFrame().GetRegisters()

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_aarch64_dynamic_regset_config(self):
        """Test AArch64 Dynamic Register sets configuration."""
        register_sets = self.setup_register_config_test()

        # Every optional register set LLDB reports must correspond to a
        # feature the target actually has.
        for registerSet in register_sets:
            set_name = registerSet.GetName()
            if "Scalable Vector Extension Registers" in set_name:
                self.assertTrue(
                    self.isAArch64SVE(),
                    "LLDB enabled AArch64 SVE register set when it was disabled by target.",
                )
                self.sve_regs_read_dynamic(registerSet)
            if "MTE Control Register" in set_name:
                self.assertTrue(
                    self.isAArch64MTE(),
                    "LLDB enabled AArch64 MTE register set when it was disabled by target.",
                )
                self.runCmd("register write mte_ctrl 0x7fff9")
                self.expect(
                    "register read mte_ctrl", substrs=["mte_ctrl = 0x000000000007fff9"]
                )
            if "Pointer Authentication Registers" in set_name:
                self.assertTrue(
                    self.isAArch64PAuth(),
                    "LLDB enabled AArch64 Pointer Authentication register set when it was disabled by target.",
                )
                self.expect("register read data_mask", substrs=["data_mask = 0x"])
                self.expect("register read code_mask", substrs=["code_mask = 0x"])
            if "Scalable Matrix Extension Registers" in set_name:
                self.assertTrue(
                    self.isAArch64SME(),
                    "LLDB Enabled SME register set when it was disabled by target",
                )

    def make_za_value(self, vl, generator):
        """Return a ZA value string "{0x.. 0x..}" of vl rows of vl bytes,
        where every byte of row ``r`` is ``generator(r)``."""
        values = []
        for row in range(vl):
            values.extend([generator(row)] * vl)
        return self._format_bytes(values)

    def make_zt0_value(self, generator):
        """Return a ZT0 value string of 512 bits (64 bytes), where byte ``i``
        is ``generator(i)``."""
        num_bytes = 512 // 8
        return self._format_bytes(generator(i) for i in range(num_bytes))

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_aarch64_dynamic_regset_config_sme(self):
        """Test AArch64 Dynamic Register sets configuration, but only SME
        registers."""
        if not self.isAArch64SMEFA64():
            self.skipTest("SME and the smefa64 extension must be present")

        register_sets = self.setup_register_config_test("sme")

        ssve_registers = register_sets.GetFirstValueByName(
            "Scalable Vector Extension Registers"
        )
        self.assertTrue(ssve_registers.IsValid())
        self.sve_regs_read_dynamic(ssve_registers)

        sme_registers = register_sets.GetFirstValueByName(
            "Scalable Matrix Extension Registers"
        )
        self.assertTrue(sme_registers.IsValid())

        vg = ssve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned()
        vl = vg * 8
        # When first enabled it is all 0s.
        self.expect("register read za", substrs=[self.make_za_value(vl, lambda r: 0)])
        za_value = self.make_za_value(vl, lambda r: r + 1)
        self.runCmd("register write za '{}'".format(za_value))
        self.expect("register read za", substrs=[za_value])

        # SVG should match VG because we're in streaming mode.
        svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned()
        self.assertEqual(vg, svg)

        # SVCR should be SVCR.SM | SVCR.ZA aka 3 because streaming mode is on
        # and ZA is enabled.
        svcr = sme_registers.GetChildMemberWithName("svcr").GetValueAsUnsigned()
        self.assertEqual(3, svcr)

        # SVCR is read only so we do not test writing to it.

    def write_to_enable_za_test(self, has_zt0, write_za_first):
        """Start with ZA (and ZT0) disabled, then write to either ZA or ZT0
        and check that this enables ZA.

        has_zt0: whether the target has ZT0 (SME2).
        write_za_first: True to enable by writing ZA, False to enable by
        writing ZT0 (requires has_zt0).
        """
        # No argument, so ZA and ZT0 will be disabled when we break.
        register_sets = self.setup_register_config_test()

        # vg is the non-streaming vg as we are in non-streaming mode, so we need
        # to use svg.
        sme_registers = register_sets.GetFirstValueByName(
            "Scalable Matrix Extension Registers"
        )
        self.assertTrue(sme_registers.IsValid())
        svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned()

        # We are not in streaming mode, ZA is disabled, so this should be 0.
        svcr = sme_registers.GetChildMemberWithName("svcr").GetValueAsUnsigned()
        self.assertEqual(0, svcr)

        svl = svg * 8
        # A disabled ZA is shown as all 0s.
        disabled_za = self.make_za_value(svl, lambda r: 0)
        self.expect("register read za", substrs=[disabled_za])

        disabled_zt0 = self.make_zt0_value(lambda n: 0)
        if has_zt0:
            # A disabled zt0 is all 0s.
            self.expect("register read zt0", substrs=[disabled_zt0])

        # Writing to ZA or ZT0 enables both and we should be able to read the
        # value back.
        za_value = self.make_za_value(svl, lambda r: r + 1)
        zt0_value = self.make_zt0_value(lambda n: n + 1)

        if write_za_first:
            # This enables ZA and ZT0.
            self.runCmd("register write za '{}'".format(za_value))
            self.expect("register read za", substrs=[za_value])

            if has_zt0:
                # ZT0 is still 0s at this point, though it is active.
                self.expect("register read zt0", substrs=[disabled_zt0])

                # Now write ZT0 so we can check it reads back correctly.
                self.runCmd("register write zt0 '{}'".format(zt0_value))
                self.expect("register read zt0", substrs=[zt0_value])
        else:
            if not has_zt0:
                self.fail("Cannot write to zt0 when sme2 is not present.")

            # Instead use the write of ZT0 to activate ZA.
            self.runCmd("register write zt0 '{}'".format(zt0_value))
            self.expect("register read zt0", substrs=[zt0_value])

            # ZA is now active but still all 0s.
            self.expect("register read za", substrs=[disabled_za])

            # Write and read back ZA.
            self.runCmd("register write za '{}'".format(za_value))
            self.expect("register read za", substrs=[za_value])

        # Now SVCR.ZA should be set, which is bit 1.
        self.expect("register read svcr", substrs=["0x0000000000000002"])

        # SVCR is read only so we do not test writing to it.

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_aarch64_dynamic_regset_config_sme_write_za_to_enable(self):
        """Test that ZA and ZT0 (if present) shows as 0s when disabled and
        can be enabled by writing to ZA."""
        if not self.isAArch64SME():
            self.skipTest("SME must be present.")

        self.write_to_enable_za_test(self.isAArch64SME2(), True)

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_aarch64_dynamic_regset_config_sme_write_zt0_to_enable(self):
        """Test that ZA and ZT0 (if present) shows as 0s when disabled and
        can be enabled by writing to ZT0."""
        if not self.isAArch64SME():
            self.skipTest("SME must be present.")
        if not self.isAArch64SME2():
            self.skipTest("SME2 must be present.")

        # Enable via a ZT0 write, not a ZA write; that is what this test is for.
        self.write_to_enable_za_test(has_zt0=True, write_za_first=False)
279