xref: /llvm-project/lldb/packages/Python/lldbsuite/test/lldbinline.py (revision e634c2f7149392b62e93c1b2b75701a12bc06721)
1# System modules
2import os
3import textwrap
4
5# Third-party modules
6import io
7
8# LLDB modules
9import lldb
10from .lldbtest import *
11from . import configuration
12from . import lldbutil
13from .decorators import *
14
15
def source_type(filename):
    """Map a source file name to the Makefile variable that should list it.

    Returns ``"C_SOURCES"``, ``"CXX_SOURCES"``, ``"OBJC_SOURCES"`` or
    ``"OBJCXX_SOURCES"`` for the recognised extensions (matched
    case-sensitively), and ``None`` for anything else.
    """
    suffix = os.path.splitext(filename)[1]
    if suffix == ".c":
        return "C_SOURCES"
    if suffix in (".cpp", ".cxx", ".cc"):
        return "CXX_SOURCES"
    if suffix == ".m":
        return "OBJC_SOURCES"
    if suffix == ".mm":
        return "OBJCXX_SOURCES"
    return None
26
27
class CommandParser:
    """Collect ``//%`` inline commands from source files and bind them to breakpoints.

    A source line containing ``//%`` carries a Python command (the text after
    the marker).  If code precedes the marker, that line starts a new
    breakpoint.  A marker line whose code part is blank, directly following
    another marker line, extends the previous breakpoint's command.

    Each entry of ``self.breakpoints`` is a dict with the keys ``file_name``,
    ``line_number`` (1-based), ``command`` and, once ``set_breakpoints`` has
    run, ``breakpoint`` (the object returned by the target).
    """

    def __init__(self):
        self.breakpoints = []

    def parse_one_command(self, line):
        """Split ``line`` at the ``//%`` marker.

        Returns ``(command, new_breakpoint)``:

        * ``command`` is the text after the marker with trailing whitespace
          stripped, or ``None`` unless the line contains exactly one marker.
        * ``new_breakpoint`` is False only for a single-marker line with
          nothing but whitespace before the marker (a continuation line).
        """
        parts = line.split("//%")

        command = None
        new_breakpoint = True

        if len(parts) == 2:
            command = parts[1].rstrip()
            new_breakpoint = parts[0].strip() != ""

        return (command, new_breakpoint)

    def parse_source_files(self, source_files):
        """Parse each file in ``source_files`` and append its breakpoints.

        Files are read as UTF-8.  After all files are parsed, every collected
        command is passed through ``textwrap.dedent``.  This removes the
        whitespace common to all of its lines, such as the space usually
        written after ``//%``.
        """
        for source_file in source_files:
            # Close the handle promptly instead of leaking it until GC.
            with io.open(source_file, encoding="utf-8") as file_handle:
                lines = file_handle.readlines()

            # Non-None while following marker lines may extend this breakpoint.
            current_breakpoint = None
            for line_number, line in enumerate(lines, start=1):
                command, new_breakpoint = self.parse_one_command(line)

                if new_breakpoint:
                    current_breakpoint = None

                if command is None:
                    continue

                if current_breakpoint is None:
                    current_breakpoint = {
                        "file_name": source_file,
                        "line_number": line_number,
                        "command": command,
                    }
                    self.breakpoints.append(current_breakpoint)
                else:
                    current_breakpoint["command"] += "\n" + command
        for bkpt in self.breakpoints:
            bkpt["command"] = textwrap.dedent(bkpt["command"])

    def set_breakpoints(self, target):
        """Create a file/line breakpoint on ``target`` for every parsed command."""
        for bp in self.breakpoints:
            bp["breakpoint"] = target.BreakpointCreateByLocation(
                bp["file_name"], bp["line_number"]
            )

    def handle_breakpoint(self, test, breakpoint_id):
        """Run the command of the first breakpoint whose ID is ``breakpoint_id``.

        The command is executed through ``test.execute_user_command``.  IDs
        that match no parsed breakpoint are ignored.
        """
        for bp in self.breakpoints:
            if bp["breakpoint"].GetID() == breakpoint_id:
                test.execute_user_command(bp["command"])
                return
84
85
class InlineTest(TestBase):
    """Base class for LLDB tests whose commands live inside the test sources.

    Every ``//%`` comment in the C/C++/Objective-C sources of the test
    directory marks a breakpoint.  The Python text after the marker is
    executed, with ``self`` bound to the test, whenever that breakpoint is
    hit.  ``MakeInlineTest`` creates concrete subclasses of this class.
    """

    def getBuildDirBasename(self):
        return self.__class__.__name__ + "." + self.testMethodName

    def BuildMakefile(self):
        """Write a Makefile for the test sources into the build directory.

        Does nothing if the build directory already has a Makefile.  Sources
        are grouped by ``source_type`` into ``C_SOURCES``/``CXX_SOURCES``/...
        variables.  Objective-C(++) sources add Foundation link flags, and
        C++ sources add ``-std=c++11``.
        """
        makefilePath = self.getBuildArtifact("Makefile")
        if os.path.exists(makefilePath):
            return

        # Group sources per Makefile variable.  Sort the listing so the
        # generated Makefile does not depend on filesystem enumeration order.
        categories = {}
        for f in sorted(os.listdir(self.getSourceDir())):
            t = source_type(f)
            if t:
                categories.setdefault(t, []).append(f)

        with open(makefilePath, "w+") as makefile:
            for t, files in categories.items():
                makefile.write(t + " := " + " ".join(files) + "\n")

            if "OBJCXX_SOURCES" in categories or "OBJC_SOURCES" in categories:
                makefile.write("LDFLAGS = $(CFLAGS) -lobjc -framework Foundation\n")

            if "CXX_SOURCES" in categories:
                makefile.write("CXXFLAGS += -std=c++11\n")

            # The actual build logic is delegated to Makefile.rules.
            makefile.write("include Makefile.rules\n")

    def _test(self):
        """Generate the Makefile, build the inferior, and run the inline commands."""
        self.BuildMakefile()
        self.build(dictionary=self._build_dict)
        self.do_test()

    def execute_user_command(self, __command):
        """Execute one inline ``//%`` command as Python code.

        The command sees this module's globals and the local namespace of
        this call, so ``self`` refers to the running test.  Assignments made
        by the command do not persist between commands.
        """
        exec(__command, globals(), locals())

    def _get_breakpoint_ids(self, thread):
        """Return the sorted, de-duplicated breakpoint IDs ``thread`` stopped at.

        For a breakpoint stop, the stop-reason data is a sequence of
        (breakpoint ID, location ID) pairs, so only the even indices are read.
        Fails the test if no ID is reported.
        """
        ids = set()
        for i in range(0, thread.GetStopReasonDataCount(), 2):
            ids.add(thread.GetStopReasonDataAtIndex(i))
        self.assertGreater(len(ids), 0)
        return sorted(ids)

    def do_test(self):
        """Launch the built ``a.out`` and run the inline commands at each stop.

        One breakpoint is set per ``//%`` command group.  While some thread
        is stopped at a breakpoint, the commands of every reported breakpoint
        are run and the process is continued.  Fails if no breakpoint was
        ever hit or if the process ends neither stopped nor exited.
        """
        exe = self.getBuildArtifact("a.out")
        source_files = [f for f in os.listdir(self.getSourceDir()) if source_type(f)]
        target = self.dbg.CreateTarget(exe)

        parser = CommandParser()
        parser.parse_source_files(source_files)
        parser.set_breakpoints(target)

        process = target.LaunchSimple(None, None, self.get_process_working_directory())
        # LaunchSimple returns an SBProcess object even when launching fails,
        # so a None check can never fire; check validity instead.
        self.assertTrue(process.IsValid(), PROCESS_IS_VALID)

        hit_breakpoints = 0
        while True:
            thread = lldbutil.get_stopped_thread(process, lldb.eStopReasonBreakpoint)
            if not thread:
                break
            hit_breakpoints += 1
            for bp_id in self._get_breakpoint_ids(thread):
                parser.handle_breakpoint(self, bp_id)
            process.Continue()

        self.assertTrue(
            hit_breakpoints > 0, "inline test did not hit a single breakpoint"
        )
        # Either the process exited or the stepping plan is complete.
        self.assertIn(
            process.GetState(),
            [lldb.eStateStopped, lldb.eStateExited],
            PROCESS_EXITED,
        )

    def check_expression(self, expression, expected_result, use_summary=True):
        """Evaluate ``expression`` in the selected frame and compare the result.

        The comparison uses the value's summary when ``use_summary`` is true
        and its plain value string otherwise.
        """
        value = self.frame().EvaluateExpression(expression)
        self.assertTrue(value.IsValid(), expression + " returned a valid value")
        if self.TraceOn():
            print(value.GetSummary())
            print(value.GetValue())
        if use_summary:
            answer = value.GetSummary()
        else:
            answer = value.GetValue()
        report_str = "%s expected: %s got: %s" % (expression, expected_result, answer)
        # assertEqual also shows both operands when the check fails.
        self.assertEqual(answer, expected_result, report_str)
176
177
def ApplyDecoratorsToFunction(func, decorators):
    """Return ``func`` wrapped by ``decorators``.

    ``decorators`` may be:

    * a list or tuple of decorators, applied in order so the first one ends
      up innermost;
    * a single callable decorator;
    * ``None`` or any other non-callable value, in which case ``func`` is
      returned unchanged.
    """
    # Tuples were previously ignored silently; treat them like lists.
    if isinstance(decorators, (list, tuple)):
        for decorator in decorators:
            func = decorator(func)
    elif callable(decorators):
        func = decorators(func)
    return func
186
187
def MakeInlineTest(__file, __globals, decorators=None, name=None, build_dict=None):
    """Create, register and return an ``InlineTest`` subclass for ``__file``.

    The new class is named ``name``; by default this is the base name of
    ``__file`` without its extension.  Its ``test`` method is
    ``InlineTest._test`` wrapped by ``decorators``, and ``build_dict`` is
    stored as ``_build_dict``.  The class is inserted into ``__globals``
    under its name.  Its ``test_filename`` and ``mydir`` are derived from
    ``__file``.
    """
    # Refer to the Python source, not its compiled ".pyc" variant.
    if __file is not None and __file.endswith(".pyc"):
        __file = __file[:-1]

    if name is None:
        name = os.path.splitext(os.path.basename(__file))[0]

    members = {
        "test": ApplyDecoratorsToFunction(InlineTest._test, decorators),
        "name": name,
        "_build_dict": build_dict,
    }
    test_class = type(name, (InlineTest,), members)

    # Expose the generated test case from the calling module.
    __globals[name] = test_class

    # Remember the originating script so results are reported against it.
    test_class.test_filename = __file
    test_class.mydir = TestBase.compute_mydir(__file)
    return test_class
214