1"""
Test the AArch64 SVE and streaming SVE (SSVE) registers.
3"""
4
5from enum import Enum
6import lldb
7from lldbsuite.test.decorators import *
8from lldbsuite.test.lldbtest import *
9from lldbsuite.test import lldbutil
10
11
class Mode(Enum):
    """Execution mode the inferior is started in, or an expression runs in."""

    # Normal (non-streaming) SVE mode.
    SVE = 0
    # Streaming SVE mode (entered via SMSTART SM in the test program).
    SSVE = 1
15
16
class RegisterCommandsTestCase(TestBase):
    """Check size, read and write of the AArch64 SVE/SSVE Z, P and FFR registers.

    main.c is built either for normal SVE mode or, with -DSTART_SSVE, to start
    in streaming SVE mode. The tests stop at the "// Set a break point here."
    line and inspect the "Scalable Vector Extension Registers" register set.
    """

    @staticmethod
    def _sve_format_bytes(byte_str, count):
        """Return "{b b ... b}" made of ``count`` copies of ``byte_str``.

        This is the brace-enclosed byte list that the tests match against
        `register read` output and pass to `register write`.
        """
        return "{" + " ".join([byte_str] * count) + "}"

    def check_sve_register_size(self, set, name, expected):
        """Assert that register ``name`` in register set ``set`` is ``expected`` bytes.

        ``set`` is the register-set SBValue returned by GetFirstValueByName.
        The parameter name shadows the builtin ``set`` but is kept so the
        signature stays unchanged.
        """
        reg_value = set.GetChildMemberWithName(name)
        self.assertTrue(
            reg_value.IsValid(), 'Verify we have a register named "%s"' % (name)
        )
        self.assertEqual(
            reg_value.GetByteSize(), expected, 'Verify "%s" == %i' % (name, expected)
        )

    def check_sve_regs_read(self, z_reg_size, expected_mode):
        """Check the register contents expected at the breakpoint.

        :param z_reg_size: size in bytes of each Z register.
        :param expected_mode: the Mode the process is expected to be in.

        Expected contents: every byte of zN is N + 1, every byte of pN is
        ``p_value_bytes[N % 5]``, and every byte of ffr is 0xff.
        """
        if self.isAArch64SME():
            # This test uses SMSTART SM, which only enables streaming mode,
            # leaving ZA disabled.
            expected_value = "1" if expected_mode == Mode.SSVE else "0"
            self.expect(
                "register read svcr", substrs=["0x000000000000000" + expected_value]
            )

        # P and FFR registers are 1/8 the size of a Z register.
        p_reg_size = z_reg_size // 8

        for i in range(32):
            z_regs_value = self._sve_format_bytes("0x{:02x}".format(i + 1), z_reg_size)
            self.expect("register read " + "z%i" % (i), substrs=[z_regs_value])

        p_value_bytes = ["0xff", "0x55", "0x11", "0x01", "0x00"]
        for i in range(16):
            p_regs_value = self._sve_format_bytes(p_value_bytes[i % 5], p_reg_size)
            self.expect("register read " + "p%i" % (i), substrs=[p_regs_value])

        # ffr is expected to have every byte set to 0xff. Build that value
        # explicitly rather than relying on the value left over from the
        # last iteration of the P-register loop.
        ffr_value = self._sve_format_bytes("0xff", p_reg_size)
        self.expect("register read ffr", substrs=[ffr_value])

    def check_sve_regs_read_after_write(self, z_reg_size):
        """Write known values to every Z, P and FFR register and read them back.

        Every byte of z0-z31 is set to 0x9d; every byte of p0-p15 and ffr is
        set to 0xee.
        """
        p_reg_size = z_reg_size // 8

        z_regs_value = self._sve_format_bytes("0x9d", z_reg_size)
        p_regs_value = self._sve_format_bytes("0xee", p_reg_size)

        for i in range(32):
            self.runCmd("register write " + "z%i" % (i) + " '" + z_regs_value + "'")

        for i in range(32):
            self.expect("register read " + "z%i" % (i), substrs=[z_regs_value])

        for i in range(16):
            self.runCmd("register write " + "p%i" % (i) + " '" + p_regs_value + "'")

        for i in range(16):
            self.expect("register read " + "p%i" % (i), substrs=[p_regs_value])

        self.runCmd("register write " + "ffr " + "'" + p_regs_value + "'")

        self.expect("register read " + "ffr", substrs=[p_regs_value])

    def get_build_flags(self, mode):
        """Return the build() dictionary for ``mode``.

        SVE is always enabled; START_SSVE is additionally defined for SSVE.
        """
        cflags = "-march=armv8-a+sve"
        if mode == Mode.SSVE:
            cflags += " -DSTART_SSVE"
        return {"CFLAGS_EXTRAS": cflags}

    def skip_if_needed(self, mode):
        """Skip the test if the host lacks the features ``mode`` requires."""
        if (mode == Mode.SVE) and not self.isAArch64SVE():
            self.skipTest("SVE registers must be supported.")

        if (mode == Mode.SSVE) and not self.isAArch64SMEFA64():
            self.skipTest(
                "SSVE registers must be supported and the smefa64 "
                "extension must be present."
            )

    def _launch_and_get_sve_registers(self, mode):
        """Build for ``mode``, run to the breakpoint and return the SVE register set.

        The returned SBValue is the "Scalable Vector Extension Registers"
        set of frame 0 of thread 0. The method asserts that the set exists.
        """
        self.build(dictionary=self.get_build_flags(mode))
        self.line = line_number("main.c", "// Set a break point here.")

        exe = self.getBuildArtifact("a.out")
        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)

        lldbutil.run_break_set_by_file_and_line(
            self, "main.c", self.line, num_expected_locations=1
        )
        self.runCmd("run", RUN_SUCCEEDED)

        self.expect(
            "thread backtrace",
            STOPPED_DUE_TO_BREAKPOINT,
            substrs=["stop reason = breakpoint 1."],
        )

        process = self.dbg.GetSelectedTarget().GetProcess()
        frame = process.GetThreadAtIndex(0).GetFrameAtIndex(0)
        sve_registers = frame.GetRegisters().GetFirstValueByName(
            "Scalable Vector Extension Registers"
        )
        self.assertTrue(sve_registers)
        return sve_registers

    def sve_registers_configuration_impl(self, mode):
        """Check that z0-z31 are vg * 8 bytes and p0-p15/ffr are 1/8 of that."""
        self.skip_if_needed(mode)

        sve_registers = self._launch_and_get_sve_registers(mode)

        vg_reg_value = sve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned()

        z_reg_size = vg_reg_value * 8
        for i in range(32):
            self.check_sve_register_size(sve_registers, "z%i" % (i), z_reg_size)

        # Integer division keeps the expected size an int, matching the type
        # returned by GetByteSize().
        p_reg_size = z_reg_size // 8
        for i in range(16):
            self.check_sve_register_size(sve_registers, "p%i" % (i), p_reg_size)

        self.check_sve_register_size(sve_registers, "ffr", p_reg_size)

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_sve_registers_configuration(self):
        """Test AArch64 SVE registers size configuration."""
        self.sve_registers_configuration_impl(Mode.SVE)

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_ssve_registers_configuration(self):
        """Test AArch64 SSVE registers size configuration."""
        self.sve_registers_configuration_impl(Mode.SSVE)

    def sve_registers_read_write_impl(self, start_mode, eval_mode):
        """Start in ``start_mode``, call a function in ``eval_mode``, then check registers.

        Register values must match before and after the expression call.
        After that check, the registers are written and read back.
        """
        self.skip_if_needed(start_mode)
        self.skip_if_needed(eval_mode)

        sve_registers = self._launch_and_get_sve_registers(start_mode)

        vg_reg_value = sve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned()
        z_reg_size = vg_reg_value * 8
        self.check_sve_regs_read(z_reg_size, start_mode)

        # Evaluate simple expression and print function expr_eval_func address.
        self.expect("expression expr_eval_func", substrs=["= 0x"])

        # Evaluate expression call function expr_eval_func.
        self.expect_expr(
            "expr_eval_func({})".format(
                "true" if (eval_mode == Mode.SSVE) else "false"
            ),
            result_type="int",
            result_value="1",
        )

        # We called a jitted function above which must not have changed SVE
        # vector length or register values.
        self.check_sve_regs_read(z_reg_size, start_mode)

        self.check_sve_regs_read_after_write(z_reg_size)

    # The following tests all set up some register values and then evaluate an
    # expression. After the expression, the mode and register values should be
    # the same as before. Finally they read/write some values in the registers.
    # The only difference is the mode we start the program in, and the mode
    # the expression function uses.

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_registers_expr_read_write_sve_sve(self):
        self.sve_registers_read_write_impl(Mode.SVE, Mode.SVE)

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_registers_expr_read_write_ssve_ssve(self):
        self.sve_registers_read_write_impl(Mode.SSVE, Mode.SSVE)

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_registers_expr_read_write_sve_ssve(self):
        self.sve_registers_read_write_impl(Mode.SVE, Mode.SSVE)

    @no_debug_info_test
    @skipIf(archs=no_match(["aarch64"]))
    @skipIf(oslist=no_match(["linux"]))
    def test_registers_expr_read_write_ssve_sve(self):
        self.sve_registers_read_write_impl(Mode.SSVE, Mode.SVE)
231