1# Usage:
2# ./bin/lldb $LLVM/lldb/test/API/functionalities/interactive_scripted_process/main \
3#   -o "br set -p 'Break here'" \
4#   -o "command script import $LLVM/lldb/test/API/functionalities/interactive_scripted_process/interactive_scripted_process.py" \
5#   -o "create_mux" \
6#   -o "create_sub" \
7#   -o "br set -p 'also break here'" -o 'continue'
8
9import os, json, struct, signal, tempfile
10
11from threading import Thread
12from typing import Any, Dict
13
14import lldb
15from lldb.plugins.scripted_process import PassthroughScriptedProcess
16from lldb.plugins.scripted_process import PassthroughScriptedThread
17
class MultiplexedScriptedProcess(PassthroughScriptedProcess):
    """Scripted process exposing a parity-filtered slice of a multiplexer.

    Its driving target is the multiplexer's target (see
    ``create_child_processes``). Each instance carries a ``parity`` read from
    the ``parity`` launch argument, and reports ``parity + 420`` as its process
    ID. Once ``multiplexer`` is set (see ``multiplex``), resumes, thread
    queries and breakpoint creation are forwarded to the multiplexer.
    """

    def __init__(self, exe_ctx: lldb.SBExecutionContext, args: lldb.SBStructuredData):
        super().__init__(exe_ctx, args)
        self.multiplexer = None
        # Set to True by `launch`; initialized here so `resume` cannot raise
        # AttributeError if it is ever reached without a prior launch.
        self.first_launch = False
        if isinstance(self.driving_process, lldb.SBProcess) and self.driving_process:
            parity = args.GetValueForKey("parity")
            # TODO: Change to Walrus operator (:=) with oneline if assignment
            # Requires python 3.8
            val = parity.GetUnsignedIntegerValue()
            if val is not None:
                self.parity = val

            # Turn PassthroughScriptedThread into MultiplexedScriptedThread
            for thread in self.threads.values():
                thread.__class__ = MultiplexedScriptedThread

    def get_process_id(self) -> int:
        """Return a synthetic PID derived from the parity (``parity + 420``)."""
        return self.parity + 420

    def launch(self, should_stop: bool = True) -> lldb.SBError:
        """Mark the next resume as the first one; nothing is actually launched."""
        self.first_launch = True
        return lldb.SBError()

    def resume(self, should_stop: bool) -> lldb.SBError:
        """Resume the process.

        The first resume after ``launch`` goes through the passthrough base
        class; every later resume is delegated to the multiplexer.
        """
        if self.first_launch:
            self.first_launch = False
            return super().resume()
        if not self.multiplexer:
            return lldb.SBError("Multiplexer is not set.")
        return self.multiplexer.resume(should_stop)

    def get_threads_info(self) -> Dict[int, Any]:
        """Return this process's threads, keyed by thread ID.

        Without a multiplexer this defers to the base class. Otherwise it asks
        the multiplexer for the threads matching this process's parity and
        hands back ``MultiplexedScriptedThread`` copies of them.
        """
        if not self.multiplexer:
            return super().get_threads_info()

        import copy

        filtered_threads = self.multiplexer.get_threads_info(pid=self.get_process_id())
        multiplexed_threads = {}
        for tid, thread in filtered_threads.items():
            # `MultiplexedScriptedThread(thread)` would raise TypeError: the
            # inherited constructor takes (process, args), as the multiplexer
            # uses it. Shallow-copy the multiplexer's thread instead, so the
            # multiplexer's own object keeps its class, then retag the copy.
            muxed_thread = copy.copy(thread)
            muxed_thread.__class__ = MultiplexedScriptedThread
            # MultiplexedScriptedThread.get_name reads `scripted_process.parity`,
            # which only this (child) process defines.
            # NOTE(review): assumes the thread template stores its owner in
            # `scripted_process` — confirm against the lldb template in use.
            muxed_thread.scripted_process = self
            multiplexed_threads[tid] = muxed_thread
        return multiplexed_threads

    def create_breakpoint(self, addr, error, pid=None):
        """Create a breakpoint at load address ``addr`` through the multiplexer.

        ``pid`` is ignored; this process's own ID is always forwarded. Returns
        False with ``error`` set when no multiplexer is attached.
        """
        if not self.multiplexer:
            error.SetErrorString("Multiplexer is not set.")
            # Without this early return the call below dereferences None.
            return False
        return self.multiplexer.create_breakpoint(addr, error, self.get_process_id())

    def get_scripted_thread_plugin(self) -> str:
        """Return the fully-qualified class name used for this process's threads."""
        return f"{MultiplexedScriptedThread.__module__}.{MultiplexedScriptedThread.__name__}"
70
class MultiplexedScriptedThread(PassthroughScriptedThread):
    """Passthrough thread whose name reflects its owning process's parity."""

    def get_name(self) -> str:
        # Odd-parity owners yield "Odd...", even-parity owners "Even...".
        is_odd = self.scripted_process.parity % 2 != 0
        prefix = "Odd" if is_odd else "Even"
        return "{}{}.thread-{}".format(
            prefix, MultiplexedScriptedThread.__name__, self.idx
        )
75
76
class MultiplexerScriptedProcess(PassthroughScriptedProcess):
    """Scripted process that owns the real ("driving") process.

    It launches the driving process itself (``launch``) and mirrors its
    threads as ``PassthroughScriptedThread`` objects from a background
    listener thread. It serves parity-filtered thread views and breakpoint
    creation to the ``MultiplexedScriptedProcess`` children that ``multiplex``
    registers in ``multiplexed_processes``.
    """

    listener = None
    multiplexed_processes = None

    def wait_for_driving_process_to_stop(self):
        """Listener-thread loop reacting to driving-process stop events.

        Never returns; ``__init__`` runs it on a daemon thread. On every public
        ``eStateStopped`` event it rebuilds ``self.threads`` from the driving
        process. It then forces the multiplexer and every registered child
        process through running -> stopped.
        """
        import time

        def handle_process_state_event():
            # Update multiplexer process
            log("Updating interactive scripted process threads")
            dbg = self.driving_target.GetDebugger()
            # Loop-invariant: the driving target's index does not change here.
            driving_target_idx = dbg.GetIndexOfTarget(self.driving_target)
            new_driving_thread_ids = set()
            for driving_thread in self.driving_process:
                new_driving_thread_ids.add(driving_thread.id)
                log(f"{len(self.threads)} New thread {hex(driving_thread.id)}")
                structured_data = lldb.SBStructuredData()
                structured_data.SetFromJSON(
                    json.dumps(
                        {
                            "driving_target_idx": driving_target_idx,
                            "thread_idx": driving_thread.GetIndexID(),
                        }
                    )
                )

                self.threads[driving_thread.id] = PassthroughScriptedThread(
                    self, structured_data
                )

            # Iterate over a snapshot of the keys: deleting entries from a dict
            # while iterating it directly raises RuntimeError.
            for thread_id in list(self.threads):
                if thread_id not in new_driving_thread_ids:
                    log(f"Removing old thread {hex(thread_id)}")
                    del self.threads[thread_id]

            log(f"New thread count: {len(self.threads)}")

            # Force running -> stopped transitions on the multiplexer and its
            # children (presumably so lldb re-queries their threads).
            mux_process = self.target.GetProcess()
            mux_process.ForceScriptedState(lldb.eStateRunning)
            mux_process.ForceScriptedState(lldb.eStateStopped)

            # Snapshot: `multiplex` may add children from the command thread.
            for child_process in list(self.multiplexed_processes.values()):
                child_process.ForceScriptedState(lldb.eStateRunning)
                child_process.ForceScriptedState(lldb.eStateStopped)

        event = lldb.SBEvent()
        while True:
            if not self.driving_process:
                # `driving_process` is only set once `launch` succeeds; back off
                # briefly instead of busy-spinning a CPU core until then.
                time.sleep(0.1)
                continue
            if not self.listener.WaitForEvent(1, event):
                continue
            if not event.GetType() & lldb.SBProcess.eBroadcastBitStateChanged:
                continue
            state = lldb.SBProcess.GetStateFromEvent(event)
            log(f"Received public process state event: {state}")
            if state == lldb.eStateStopped:
                handle_process_state_event()

    def __init__(self, exe_ctx: lldb.SBExecutionContext, args: lldb.SBStructuredData):
        """Set up the listener and mirror breakpoints.

        The driving process is not launched here. When the driving target is
        valid, this method creates the listener and mirrors the driving
        target's breakpoints into this target and back. It then starts the
        daemon listener thread.
        """
        super().__init__(exe_ctx, args, launched_driving_process=False)
        # Set to True by `launch`; initialized so `resume` never hits an
        # AttributeError.
        self.first_resume = False
        if isinstance(self.driving_target, lldb.SBTarget) and self.driving_target:
            self.listener = lldb.SBListener(
                "lldb.listener.multiplexer-scripted-process"
            )
            self.multiplexed_processes = {}

            # Copy breakpoints from real target to passthrough
            with tempfile.NamedTemporaryFile() as tf:
                bkpt_file = lldb.SBFileSpec(tf.name)
                error = self.driving_target.BreakpointsWriteToFile(bkpt_file)
                if error.Fail():
                    log(
                        "Failed to save breakpoints from driving target (%s)"
                        % error.GetCString()
                    )
                bkpts_list = lldb.SBBreakpointList(self.target)
                error = self.target.BreakpointsCreateFromFile(bkpt_file, bkpts_list)
                if error.Fail():
                    log(
                        "Failed to create breakpoints from driving target "
                        "(bkpt file: %s)" % tf.name
                    )

            # Copy breakpoints from passthrough back to the real target, as
            # address breakpoints on each resolved location.
            if error.Success():
                self.driving_target.DeleteAllBreakpoints()
                for bkpt in self.target.breakpoints:
                    if not bkpt.IsValid():
                        continue
                    for bl in bkpt:
                        real_bkpt = self.driving_target.BreakpointCreateBySBAddress(
                            bl.GetAddress()
                        )
                        if not real_bkpt.IsValid():
                            log(
                                "Failed to set breakpoint at address %s in "
                                "driving target" % hex(bl.GetLoadAddress())
                            )

            self.listener_thread = Thread(
                target=self.wait_for_driving_process_to_stop, daemon=True
            )
            self.listener_thread.start()

    def launch(self, should_stop: bool = True) -> lldb.SBError:
        """Launch the driving process with this process's listener attached.

        Records the driving target's loaded images and arms ``first_resume``.
        """
        if not self.driving_target:
            return lldb.SBError(
                f"{self.__class__.__name__}.launch: Invalid driving target."
            )

        if self.driving_process:
            return lldb.SBError(
                f"{self.__class__.__name__}.launch: Driving process is already running."
            )

        error = lldb.SBError()
        launch_info = lldb.SBLaunchInfo(None)
        launch_info.SetListener(self.listener)
        driving_process = self.driving_target.Launch(launch_info, error)

        if not driving_process or error.Fail():
            return error

        self.driving_process = driving_process

        for module in self.driving_target.modules:
            path = module.file.fullpath
            load_addr = module.GetObjectFileHeaderAddress().GetLoadAddress(
                self.driving_target
            )
            self.loaded_images.append({"path": path, "load_addr": load_addr})

        self.first_resume = True
        return error

    def resume(self, should_stop: bool = True) -> lldb.SBError:
        """Continue the driving process (the first resume after launch is a no-op)."""
        if self.first_resume:
            # When we resume the multiplexer process for the first time,
            # we shouldn't do anything because lldb's execution machinery
            # will resume the driving process by itself.

            # No need to update the multiplexer scripted process state here
            # either: once the listener thread receives the driving process's
            # stop event, `wait_for_driving_process_to_stop` updates the
            # multiplexer state for us.
            self.first_resume = False
            return lldb.SBError()

        if not self.driving_process:
            return lldb.SBError(
                f"{self.__class__.__name__}.resume: Invalid driving process."
            )

        return self.driving_process.Continue()

    def get_threads_info(self, pid: int = None) -> Dict[int, Any]:
        """Return all threads, or only those whose ID parity matches ``pid``'s."""
        if pid is None:
            return super().get_threads_info()
        parity = pid % 2
        return {tid: thread for tid, thread in self.threads.items() if tid % 2 == parity}

    def create_breakpoint(self, addr, error, pid=None):
        """Create a named address breakpoint at load address ``addr``.

        With ``pid`` set (the call came from a multiplexed child), the
        breakpoint is first created in this multiplexer's target too. It is
        then always created in the driving target. Returns True on success;
        on failure returns False with ``error`` describing the problem.
        """
        if not self.driving_target:
            error.SetErrorString("%s has no driving target." % self.__class__.__name__)
            return False

        def create_breakpoint_with_name(target, load_addr, name, error):
            # Every False return leaves a message in the caller's `error`.
            addr = lldb.SBAddress(load_addr, target)
            if not addr.IsValid():
                error.SetErrorString("Invalid breakpoint address %s" % hex(load_addr))
                return False
            bkpt = target.BreakpointCreateBySBAddress(addr)
            if not bkpt.IsValid():
                error.SetErrorString(
                    "Failed to create breakpoint at address %s"
                    % hex(addr.GetLoadAddress())
                )
                return False
            # Use a separate SBError: rebinding `error` would drop the failure
            # instead of reporting it to the caller.
            name_error = bkpt.AddNameWithErrorHandling(name)
            if name_error.Fail():
                error.SetErrorString(
                    "Failed to name breakpoint '%s': %s"
                    % (name, name_error.GetCString())
                )
                return False
            return True

        name = (
            "multiplexer_scripted_process"
            if pid is None
            else f"multiplexed_scripted_process_{pid}"
        )

        if pid is not None:
            # This means that this method has been called from one of the
            # multiplexed scripted processes. That also means that the
            # multiplexer target doesn't have this breakpoint created.
            mux_error = lldb.SBError()
            if not create_breakpoint_with_name(self.target, addr, name, mux_error):
                # SBError.SetError expects an error code, not a message.
                error.SetErrorString(
                    "Failed to create breakpoint in multiplexer target: %s"
                    % mux_error.GetCString()
                )
                return False
        return create_breakpoint_with_name(self.driving_target, addr, name, error)
284
285
def multiplex(mux_process, muxed_process):
    """Attach a multiplexed child process to its multiplexer.

    The child's scripted implementation gets a ``multiplexer`` reference to
    the multiplexer's scripted implementation. The child SBProcess is
    registered in the multiplexer's ``multiplexed_processes`` map under the
    child's process ID.
    """
    mux_impl = mux_process.GetScriptedImplementation()
    muxed_impl = muxed_process.GetScriptedImplementation()

    muxed_impl.multiplexer = mux_impl

    child_pid = muxed_process.GetProcessID()
    mux_impl.multiplexed_processes[child_pid] = muxed_process
293
294
def launch_scripted_process(target, class_name, dictionary):
    """Launch ``class_name`` as a scripted process inside ``target``.

    ``dictionary`` is serialized to JSON and passed to the scripted process
    as its arguments. Returns the SBProcess produced by ``SBTarget.Launch``.
    The SBError filled in by the launch is not returned to the caller.
    """
    launch_info = lldb.SBLaunchInfo(None)
    launch_info.SetProcessPluginName("ScriptedProcess")
    launch_info.SetScriptedProcessClassName(class_name)

    scripted_args = lldb.SBStructuredData()
    scripted_args.SetFromJSON(json.dumps(dictionary))
    launch_info.SetScriptedProcessDictionary(scripted_args)

    launch_error = lldb.SBError()
    return target.Launch(launch_info, launch_error)
306
307
def duplicate_target(driving_target):
    """Create another target for the same executable and triple as ``driving_target``.

    The new target is created in the same debugger as ``driving_target``.
    Returns whatever ``CreateTargetWithFileAndTargetTriple`` returns; callers
    in this file test the result for truthiness to detect failure.
    """
    # (The original allocated an SBError here that was never used.)
    debugger = driving_target.GetDebugger()
    return debugger.CreateTargetWithFileAndTargetTriple(
        driving_target.executable.fullpath, driving_target.triple
    )
314
def create_mux_process(debugger, command, exe_ctx, result, dict):
    """`create_mux` command: launch a MultiplexerScriptedProcess.

    Uses the selected target as the driving target, duplicates it, and
    launches the multiplexer scripted process in the duplicate. Failures are
    reported through ``result``.
    """
    if debugger.GetNumTargets() <= 0:
        return result.SetError(
            "Interactive scripted processes requires one non scripted process."
        )

    debugger.SetAsync(True)

    driving_target = debugger.GetSelectedTarget()
    if not driving_target:
        return result.SetError("Driving target is invalid")

    # Create a second target to host the multiplexer scripted process.
    mux_target = duplicate_target(driving_target)
    if not mux_target:
        return result.SetError(
            "Couldn't duplicate driving target to launch multiplexer scripted process"
        )

    mux_args = {"driving_target_idx": debugger.GetIndexOfTarget(driving_target)}
    mux_class = f"{__name__}.{MultiplexerScriptedProcess.__name__}"
    if not launch_scripted_process(mux_target, mux_class, mux_args):
        return result.SetError("Couldn't launch multiplexer scripted process")
339
340
def create_child_processes(debugger, command, exe_ctx, result, dict):
    """`create_sub` command: launch the even and odd multiplexed processes.

    Expects the selected target to host the multiplexer scripted process. For
    each parity (0 = even, 1 = odd), this duplicates the multiplexer's driving
    target and launches a MultiplexedScriptedProcess there, driven by the
    multiplexer target. The new process is then attached to the multiplexer.
    Failures are reported through ``result``.
    """
    if debugger.GetNumTargets() < 2:
        return result.SetError("Scripted Multiplexer process not setup")

    debugger.SetAsync(True)

    # The multiplexer scripted process is expected in the selected target.
    mux_target = debugger.GetSelectedTarget()
    if not mux_target:
        return result.SetError("Couldn't get multiplexer scripted process target")
    mux_process = mux_target.GetProcess()
    if not mux_process:
        return result.SetError("Couldn't get multiplexer scripted process")

    driving_target = mux_process.GetScriptedImplementation().driving_target
    if not driving_target:
        return result.SetError("Driving target is invalid")

    class_name = f"{__name__}.{MultiplexedScriptedProcess.__name__}"
    mux_target_idx = debugger.GetIndexOfTarget(mux_target)

    for parity, label in ((0, "even"), (1, "odd")):
        child_target = duplicate_target(driving_target)
        if not child_target:
            return result.SetError(
                f"Couldn't duplicate driving target to launch multiplexed {label} scripted process"
            )

        child_args = {"driving_target_idx": mux_target_idx, "parity": parity}
        child_process = launch_scripted_process(child_target, class_name, child_args)
        if not child_process:
            return result.SetError(
                f"Couldn't launch multiplexed {label} scripted process"
            )
        multiplex(mux_process, child_process)
386
387
def log(message):
    """Debug-log ``message``; currently a deliberate no-op.

    FIXME: messages are discarded until they can be routed to an lldb logging
    channel. Flip ``enabled`` to print them instead.
    """
    enabled = False
    if not enabled:
        return
    print(message)
394
395
def __lldb_init_module(dbg, dict):
    """Register the ``create_mux`` and ``create_sub`` lldb commands.

    The function paths are built from ``__name__``, not a hard-coded module
    name. This matches how the scripted-process class names are built
    elsewhere in this module, and keeps the commands working if the script is
    imported under a different module name.
    """
    for function_name, command_name in (
        ("create_mux_process", "create_mux"),
        ("create_child_processes", "create_sub"),
    ):
        dbg.HandleCommand(
            f"command script add -o -f {__name__}.{function_name} {command_name}"
        )
403