1# Usage:
2# ./bin/lldb $LLVM/lldb/test/API/functionalities/interactive_scripted_process/main \
3#   -o "br set -p 'Break here'" \
4#   -o "command script import $LLVM/lldb/test/API/functionalities/interactive_scripted_process/interactive_scripted_process.py" \
5#   -o "create_mux" \
6#   -o "create_sub" \
7#   -o "br set -p 'also break here'" -o 'continue'
8
9import os, json, struct, signal, tempfile
10
11from threading import Thread
12from typing import Any, Dict
13
14import lldb
15from lldb.plugins.scripted_process import PassthroughScriptedProcess
16from lldb.plugins.scripted_process import PassthroughScriptedThread
17
18
class MultiplexedScriptedProcess(PassthroughScriptedProcess):
    """Scripted process exposing one parity class of the multiplexer's threads.

    Each instance carries a ``parity`` read from the ``"parity"`` launch
    argument (the ``create_sub`` command uses 0 for even and 1 for odd) and
    reports a process ID derived from it. Until a multiplexer is attached via
    ``multiplex``, it behaves like a plain passthrough process over its
    driving process; afterwards, resume, thread enumeration and breakpoint
    creation are delegated to the multiplexer.
    """

    def __init__(self, exe_ctx: lldb.SBExecutionContext, args: lldb.SBStructuredData):
        super().__init__(exe_ctx, args)
        # Attached by `multiplex`; None while running standalone.
        self.multiplexer = None
        if isinstance(self.driving_process, lldb.SBProcess) and self.driving_process:
            parity = args.GetValueForKey("parity")
            # TODO: Change to Walrus operator (:=) with oneline if assignment
            # Requires python 3.8
            val = parity.GetUnsignedIntegerValue()
            if val is not None:
                self.parity = val

            # Turn PassthroughScriptedThread into MultiplexedScriptedThread
            for thread in self.threads.values():
                thread.__class__ = MultiplexedScriptedThread

    def get_process_id(self) -> int:
        """Return ``parity + 420``.

        The multiplexer selects threads by ``pid % 2``, so the ID must keep the
        parity of ``self.parity`` (420 is even).
        """
        return self.parity + 420

    def launch(self, should_stop: bool = True) -> lldb.SBError:
        """Nothing to launch; remember that the next resume is the first one."""
        self.first_launch = True
        return lldb.SBError()

    def resume(self, should_stop: bool) -> lldb.SBError:
        """Resume the process.

        The first resume after ``launch`` uses the base scripted-process
        resume; later resumes are forwarded to the multiplexer, which fails if
        none has been attached.
        """
        if self.first_launch:
            self.first_launch = False
            return super().resume()
        if not self.multiplexer:
            return lldb.SBError("Multiplexer is not set.")
        return self.multiplexer.resume(should_stop)

    def get_threads_info(self) -> Dict[int, Any]:
        """Return this process's threads keyed by thread ID.

        Without a multiplexer, this defers to the passthrough implementation.
        With one, it returns the multiplexer's threads whose ID parity matches
        this process, exposed as MultiplexedScriptedThread objects.
        """
        if not self.multiplexer:
            return super().get_threads_info()

        import copy

        filtered_threads = self.multiplexer.get_threads_info(pid=self.get_process_id())
        # The thread objects belong to the multiplexer and are shared, so
        # expose shallow copies re-typed as MultiplexedScriptedThread instead
        # of mutating the originals. Each copy is re-parented to this process
        # because MultiplexedScriptedThread.get_name reads
        # `scripted_process.parity`, which only multiplexed processes have.
        # (Calling MultiplexedScriptedThread(thread) fails: the inherited
        # __init__ requires (process, args).)
        muxed_threads = {}
        for tid, thread in filtered_threads.items():
            muxed_thread = copy.copy(thread)
            muxed_thread.__class__ = MultiplexedScriptedThread
            muxed_thread.scripted_process = self
            muxed_threads[tid] = muxed_thread
        return muxed_threads

    def create_breakpoint(self, addr, error, pid=None):
        """Ask the multiplexer to create a breakpoint at ``addr`` for this process.

        ``pid`` is accepted for interface compatibility and ignored; this
        process's own ID is always forwarded. Returns False with ``error`` set
        when no multiplexer is attached.
        """
        if not self.multiplexer:
            error.SetErrorString("Multiplexer is not set.")
            # Bail out: without this return the call below would dereference None.
            return False
        return self.multiplexer.create_breakpoint(addr, error, self.get_process_id())

    def get_scripted_thread_plugin(self) -> str:
        """Return the fully qualified name of the scripted thread class to use."""
        return f"{MultiplexedScriptedThread.__module__}.{MultiplexedScriptedThread.__name__}"
71
72
class MultiplexedScriptedThread(PassthroughScriptedThread):
    """Passthrough thread whose name reflects its owning process's parity."""

    def get_name(self) -> str:
        """Return a name such as ``EvenMultiplexedScriptedThread.thread-3``.

        The prefix is "Odd" when the owning scripted process's ``parity`` is
        odd and "Even" otherwise; the suffix is this thread's ``idx``.
        """
        is_odd = self.scripted_process.parity % 2
        prefix = "Odd" if is_odd else "Even"
        class_name = MultiplexedScriptedThread.__name__
        return f"{prefix}{class_name}.thread-{self.idx}"
77
78
class MultiplexerScriptedProcess(PassthroughScriptedProcess):
    """Scripted process that drives a real process and fans it out to children.

    The multiplexer launches the real ("driving") process itself (see
    ``launch``) with a dedicated listener. A daemon thread consumes that
    listener's events. On every stop, it rebuilds this process's passthrough
    threads and forces a running -> stopped transition on this process and on
    every registered multiplexed child process. Multiplexed children delegate
    resume, thread enumeration and breakpoint creation to this object.
    """

    # Listener handed to the driving process launch; created in __init__.
    listener = None
    # Multiplexed child SBProcesses keyed by process ID (filled by `multiplex`).
    multiplexed_processes = None

    def wait_for_driving_process_to_stop(self):
        """Event loop run on the daemon listener thread; never returns.

        Polls until the driving process exists. It then handles state-changed
        events from ``self.listener``; each ``eStateStopped`` event triggers a
        thread and state refresh.
        """
        import time

        def handle_process_state_event():
            # Rebuild the multiplexer's threads from the driving process.
            log("Updating interactive scripted process threads")
            dbg = self.driving_target.GetDebugger()
            driving_target_idx = dbg.GetIndexOfTarget(self.driving_target)
            new_driving_thread_ids = set()
            for driving_thread in self.driving_process:
                new_driving_thread_ids.add(driving_thread.id)
                log(f"{len(self.threads)} New thread {hex(driving_thread.id)}")
                structured_data = lldb.SBStructuredData()
                structured_data.SetFromJSON(
                    json.dumps(
                        {
                            "driving_target_idx": driving_target_idx,
                            "thread_idx": driving_thread.GetIndexID(),
                        }
                    )
                )

                self.threads[driving_thread.id] = PassthroughScriptedThread(
                    self, structured_data
                )

            # Drop threads that no longer exist in the driving process.
            # Iterate over a snapshot of the keys: deleting from a dict while
            # iterating it directly raises RuntimeError.
            for thread_id in list(self.threads):
                if thread_id not in new_driving_thread_ids:
                    log(f"Removing old thread {hex(thread_id)}")
                    del self.threads[thread_id]

            print(f"New thread count: {len(self.threads)}")

            # NOTE(review): toggling running -> stopped presumably makes lldb
            # re-query thread info for the multiplexer and its children; confirm.
            mux_process = self.target.GetProcess()
            mux_process.ForceScriptedState(lldb.eStateRunning)
            mux_process.ForceScriptedState(lldb.eStateStopped)

            # Snapshot the values: `multiplex` can register new children from
            # another thread (e.g. the `create_sub` command) while this
            # listener thread is iterating.
            for child_process in list(self.multiplexed_processes.values()):
                child_process.ForceScriptedState(lldb.eStateRunning)
                child_process.ForceScriptedState(lldb.eStateStopped)

        event = lldb.SBEvent()
        while True:
            if not self.driving_process:
                # `launch` has not set the driving process yet. Sleep briefly
                # instead of spinning: a bare `continue` pegs a CPU core and
                # competes for the GIL with the main thread.
                time.sleep(0.1)
                continue
            if not self.listener.WaitForEvent(1, event):
                continue  # Timed out; keep waiting.
            if not event.GetType() & lldb.SBProcess.eBroadcastBitStateChanged:
                continue
            state = lldb.SBProcess.GetStateFromEvent(event)
            log(f"Received public process state event: {state}")
            if state == lldb.eStateStopped:
                # Every stop of the driving process refreshes the multiplexer's
                # threads and the state of all multiplexed children.
                handle_process_state_event()

    def __init__(self, exe_ctx: lldb.SBExecutionContext, args: lldb.SBStructuredData):
        """Set up the listener, mirror breakpoints and start the listener thread.

        The driving process is not launched here (``launched_driving_process``
        is False); ``launch`` does that. Breakpoints are round-tripped through
        a temporary file from the driving target into this target. They are
        then re-created in the driving target at the resolved locations of
        this target's breakpoints.
        """
        super().__init__(exe_ctx, args, launched_driving_process=False)
        if isinstance(self.driving_target, lldb.SBTarget) and self.driving_target:
            self.listener = lldb.SBListener(
                "lldb.listener.multiplexer-scripted-process"
            )
            self.multiplexed_processes = {}

            # Copy breakpoints from real target to passthrough
            with tempfile.NamedTemporaryFile() as tf:
                bkpt_file = lldb.SBFileSpec(tf.name)
                error = self.driving_target.BreakpointsWriteToFile(bkpt_file)
                if error.Fail():
                    log(
                        "Failed to save breakpoints from driving target (%s)"
                        % error.GetCString()
                    )
                bkpts_list = lldb.SBBreakpointList(self.target)
                error = self.target.BreakpointsCreateFromFile(bkpt_file, bkpts_list)
                if error.Fail():
                    log(
                        (
                            "Failed to create breakpoints from driving target "
                            "(bkpt file: %s)"
                        )
                        % tf.name
                    )

            # Copy breakpoint from passthrough to real target
            if error.Success():
                self.driving_target.DeleteAllBreakpoints()
                for bkpt in self.target.breakpoints:
                    if not bkpt.IsValid():
                        continue
                    for bl in bkpt:
                        real_bpkt = self.driving_target.BreakpointCreateBySBAddress(
                            bl.GetAddress()
                        )
                        if not real_bpkt.IsValid():
                            log(
                                "Failed to set breakpoint at address %s in driving target"
                                % hex(bl.GetLoadAddress())
                            )

            self.listener_thread = Thread(
                target=self.wait_for_driving_process_to_stop, daemon=True
            )
            self.listener_thread.start()

    def launch(self, should_stop: bool = True) -> lldb.SBError:
        """Launch the driving target with this process's listener attached.

        On success, it records the driving target's loaded images and marks
        the next ``resume`` as the first one. It fails if there is no driving
        target or the driving process has already been launched.
        """
        if not self.driving_target:
            return lldb.SBError(
                f"{self.__class__.__name__}.launch: Invalid driving target."
            )

        if self.driving_process:
            return lldb.SBError(
                f"{self.__class__.__name__}.launch: Driving process is already launched."
            )

        error = lldb.SBError()
        launch_info = lldb.SBLaunchInfo(None)
        # Route the driving process's events to the listener thread.
        launch_info.SetListener(self.listener)
        driving_process = self.driving_target.Launch(launch_info, error)

        if not driving_process or error.Fail():
            return error

        self.driving_process = driving_process

        for module in self.driving_target.modules:
            path = module.file.fullpath
            load_addr = module.GetObjectFileHeaderAddress().GetLoadAddress(
                self.driving_target
            )
            self.loaded_images.append({"path": path, "load_addr": load_addr})

        self.first_resume = True
        return error

    def resume(self, should_stop: bool = True) -> lldb.SBError:
        """Continue the driving process (no-op on the first resume)."""
        if self.first_resume:
            # When we resume the multiplexer process for the first time,
            # we shouldn't do anything because lldb's execution machinery
            # will resume the driving process by itself.

            # Also, no need to update the multiplexer scripted process state
            # here because it's listening for the real process stop events.
            # Once it receives the stop event from the driving process,
            # `wait_for_driving_process_to_stop` will update the multiplexer
            # state for us.

            self.first_resume = False
            return lldb.SBError()

        if not self.driving_process:
            return lldb.SBError(
                f"{self.__class__.__name__}.resume: Invalid driving process."
            )

        return self.driving_process.Continue()

    def get_threads_info(self, pid: int = None) -> Dict[int, Any]:
        """Return threads keyed by thread ID.

        With a falsy ``pid``, this returns all threads (passthrough
        behavior). Otherwise it returns only the threads whose ID has the same
        parity as ``pid``.
        """
        if not pid:
            return super().get_threads_info()
        parity = pid % 2
        return {tid: thread for tid, thread in self.threads.items() if tid % 2 == parity}

    def create_breakpoint(self, addr, error, pid=None):
        """Create a named breakpoint at load address ``addr``.

        Args:
            addr: Load address of the breakpoint.
            error: lldb.SBError filled in on failure.
            pid: When not None, the request comes from the multiplexed
                process with this ID. The breakpoint is then also created in
                the multiplexer's own target. The breakpoint name is
                ``multiplexed_scripted_process_<pid>`` for a truthy ``pid``
                and ``multiplexer_scripted_process`` otherwise.

        Returns:
            True on success; False on failure, with ``error`` set.
        """
        if not self.driving_target:
            error.SetErrorString("%s has no driving target." % self.__class__.__name__)
            return False

        def create_breakpoint_with_name(target, load_addr, name, error):
            # Create a breakpoint at `load_addr` in `target` and tag it with
            # `name`; report any failure through `error` and return False.
            addr = lldb.SBAddress(load_addr, target)
            if not addr.IsValid():
                error.SetErrorString("Invalid breakpoint address %s" % hex(load_addr))
                return False
            bkpt = target.BreakpointCreateBySBAddress(addr)
            if not bkpt.IsValid():
                # SBAddress.GetLoadAddress needs a target argument; the
                # requested load address is what we want to report anyway.
                error.SetErrorString(
                    "Failed to create breakpoint at address %s" % hex(load_addr)
                )
                return False
            name_error = bkpt.AddNameWithErrorHandling(name)
            if name_error.Fail():
                error.SetErrorString(name_error.GetCString())
                return False
            return True

        name = (
            "multiplexer_scripted_process"
            if not pid
            else f"multiplexed_scripted_process_{pid}"
        )

        if pid is not None:
            # This means that this method has been called from one of the
            # multiplexed scripted process. That also means that the multiplexer
            # target doesn't have this breakpoint created.
            mux_error = lldb.SBError()
            if not create_breakpoint_with_name(self.target, addr, name, mux_error):
                # SBError.SetError takes (code, type); use SetErrorString.
                error.SetErrorString(
                    "Failed to create breakpoint in multiplexer target: %s"
                    % mux_error.GetCString()
                )
                return False
        return create_breakpoint_with_name(self.driving_target, addr, name, error)
286
287
def multiplex(mux_process, muxed_process):
    """Register ``muxed_process`` as a child of the multiplexer ``mux_process``.

    This links the child's scripted implementation back to the multiplexer's
    implementation. It also records the child SBProcess, keyed by its process
    ID, in the multiplexer's ``multiplexed_processes`` map.
    """
    mux_impl = mux_process.GetScriptedImplementation()
    muxed_process.GetScriptedImplementation().multiplexer = mux_impl
    child_pid = muxed_process.GetProcessID()
    mux_impl.multiplexed_processes[child_pid] = muxed_process
295
296
def launch_scripted_process(target, class_name, dictionary):
    """Launch a ScriptedProcess implemented by ``class_name`` in ``target``.

    ``dictionary`` is JSON-serialized and passed to the scripted process as
    its arguments. Returns the SBProcess from ``target.Launch``. The launch
    SBError is discarded, so callers can only test the process for validity.
    """
    launch_info = lldb.SBLaunchInfo(None)
    launch_info.SetProcessPluginName("ScriptedProcess")
    launch_info.SetScriptedProcessClassName(class_name)

    scripted_args = lldb.SBStructuredData()
    scripted_args.SetFromJSON(json.dumps(dictionary))
    launch_info.SetScriptedProcessDictionary(scripted_args)

    return target.Launch(launch_info, lldb.SBError())
308
309
def duplicate_target(driving_target):
    """Create a new target for the same executable and triple as ``driving_target``.

    The new target is created through ``driving_target``'s debugger. The
    resulting SBTarget is returned as-is; callers check its validity.
    """
    # The unused `lldb.SBError()` local from the original has been removed.
    debugger = driving_target.GetDebugger()
    return debugger.CreateTargetWithFileAndTargetTriple(
        driving_target.executable.fullpath, driving_target.triple
    )
316
317
def create_mux_process(debugger, command, exe_ctx, result, dict):
    """`create_mux` command: launch the multiplexer scripted process.

    The debugger's selected target is used as the real ("driving") target. A
    copy of it is created, and a MultiplexerScriptedProcess is launched in
    the copy, pointed at the driving target by index. Failures are reported
    through ``result``.
    """
    if debugger.GetNumTargets() <= 0:
        return result.SetError(
            "Interactive scripted processes requires one non scripted process."
        )

    debugger.SetAsync(True)

    driving_target = debugger.GetSelectedTarget()
    if not driving_target:
        return result.SetError("Driving target is invalid")

    # The multiplexer scripted process runs in its own second target.
    mux_target = duplicate_target(driving_target)
    if not mux_target:
        return result.SetError(
            "Couldn't duplicate driving target to launch multiplexer scripted process"
        )

    mux_class = f"{__name__}.{MultiplexerScriptedProcess.__name__}"
    launch_args = {"driving_target_idx": debugger.GetIndexOfTarget(driving_target)}
    if not launch_scripted_process(mux_target, mux_class, launch_args):
        return result.SetError("Couldn't launch multiplexer scripted process")
342
343
def create_child_processes(debugger, command, exe_ctx, result, dict):
    """`create_sub` command: launch the even and odd multiplexed processes.

    The selected target is expected to hold the multiplexer scripted process.
    For parity 0 ("even") and then 1 ("odd"), this creates a copy of the
    multiplexer's driving target, launches a MultiplexedScriptedProcess in
    it, and registers the new process with the multiplexer. It stops at the
    first failure and reports it through ``result``.
    """
    if debugger.GetNumTargets() < 2:
        return result.SetError("Scripted Multiplexer process not setup")

    debugger.SetAsync(True)

    # The multiplexer scripted process lives in the currently selected target.
    mux_target = debugger.GetSelectedTarget()
    if not mux_target:
        return result.SetError("Couldn't get multiplexer scripted process target")
    mux_process = mux_target.GetProcess()
    if not mux_process:
        return result.SetError("Couldn't get multiplexer scripted process")

    driving_target = mux_process.GetScriptedImplementation().driving_target
    if not driving_target:
        return result.SetError("Driving target is invalid")

    child_class = f"{__name__}.{MultiplexedScriptedProcess.__name__}"
    for parity, label in ((0, "even"), (1, "odd")):
        # Each multiplexed child gets its own copy of the driving target.
        child_target = duplicate_target(driving_target)
        if not child_target:
            return result.SetError(
                f"Couldn't duplicate driving target to launch multiplexed {label} scripted process"
            )

        launch_args = {
            "driving_target_idx": debugger.GetIndexOfTarget(mux_target),
            "parity": parity,
        }
        child_process = launch_scripted_process(child_target, child_class, launch_args)
        if not child_process:
            return result.SetError(
                f"Couldn't launch multiplexed {label} scripted process"
            )
        multiplex(mux_process, child_process)
389
390
def log(message):
    """Debug logging hook; messages are currently dropped.

    FIXME: For now, we discard the log message until we can pass it to an
    lldb logging channel. Flip ``enabled`` to print messages to stdout.
    """
    enabled = False
    if not enabled:
        return
    print(message)
397
398
def __lldb_init_module(dbg, dict):
    """LLDB hook run on `command script import`: register this file's commands.

    ``create_mux`` runs ``create_mux_process`` and ``create_sub`` runs
    ``create_child_processes``. The ``-f`` function path is built from this
    module's ``__name__`` rather than a hard-coded module name, so the
    commands still resolve if the script is imported under a different name.
    """
    for function_name, command_name in (
        ("create_mux_process", "create_mux"),
        ("create_child_processes", "create_sub"),
    ):
        dbg.HandleCommand(
            f"command script add -o -f {__name__}.{function_name} {command_name}"
        )
406