1# Usage:
2# ./bin/lldb $LLVM/lldb/test/API/functionalities/interactive_scripted_process/main \
3#   -o "br set -p 'Break here'" \
4#   -o "command script import $LLVM/lldb/test/API/functionalities/interactive_scripted_process/interactive_scripted_process.py" \
5#   -o "create_mux" \
6#   -o "create_sub" \
7#   -o "br set -p 'also break here'" -o 'continue'
8
9import os, json, struct, signal, tempfile
10
11from threading import Thread
12from typing import Any, Dict
13
14import lldb
15from lldb.plugins.scripted_process import ScriptedProcess
16from lldb.plugins.scripted_process import ScriptedThread
17
18
class PassthruScriptedProcess(ScriptedProcess):
    """Scripted process that forwards its queries to a real "driving" process.

    The driving target is chosen by its index in the debugger's target list,
    passed as ``driving_target_idx`` (integer or numeric string) in ``args``.
    When ``launched_driving_process`` is true, the driving target's current
    process is adopted right away: each of its threads is mirrored as a
    ``PassthruScriptedThread`` and each of its modules is reported as a loaded
    image. Memory region queries, reads and writes are forwarded to the
    driving process.
    """

    driving_target = None
    driving_process = None

    def __init__(
        self,
        exe_ctx: lldb.SBExecutionContext,
        args: lldb.SBStructuredData,
        launched_driving_process: bool = True,
    ):
        super().__init__(exe_ctx, args)

        self.driving_target = None
        self.driving_process = None

        self.driving_target_idx = args.GetValueForKey("driving_target_idx")

        # Only integer and string indices are supported. Any other value type
        # leaves the process without a driving target instead of failing with
        # an unbound `idx`.
        idx = None
        if self.driving_target_idx and self.driving_target_idx.IsValid():
            data_type = self.driving_target_idx.GetType()
            if data_type == lldb.eStructuredDataTypeInteger:
                idx = self.driving_target_idx.GetIntegerValue(42)
            elif data_type == lldb.eStructuredDataTypeString:
                idx = int(self.driving_target_idx.GetStringValue(100))

        if idx is None:
            return

        self.driving_target = self.target.GetDebugger().GetTargetAtIndex(idx)

        # Subclasses that launch the driving process themselves pass False and
        # populate threads/images later.
        if not launched_driving_process:
            return

        self.driving_process = self.driving_target.GetProcess()
        for driving_thread in self.driving_process:
            structured_data = lldb.SBStructuredData()
            structured_data.SetFromJSON(
                json.dumps(
                    {
                        "driving_target_idx": idx,
                        "thread_idx": driving_thread.GetIndexID(),
                    }
                )
            )

            self.threads[driving_thread.GetThreadID()] = PassthruScriptedThread(
                self, structured_data
            )

        for module in self.driving_target.modules:
            path = module.file.fullpath
            load_addr = module.GetObjectFileHeaderAddress().GetLoadAddress(
                self.driving_target
            )
            self.loaded_images.append({"path": path, "load_addr": load_addr})

    def get_memory_region_containing_address(
        self, addr: int
    ) -> lldb.SBMemoryRegionInfo:
        """Return the driving process' memory region containing ``addr``.

        Returns None when there is no driving process yet or the lookup fails.
        """
        if self.driving_process is None:
            return None
        mem_region = lldb.SBMemoryRegionInfo()
        error = self.driving_process.GetMemoryRegionInfo(addr, mem_region)
        if error.Fail():
            return None
        return mem_region

    def read_memory_at_address(
        self, addr: int, size: int, error: lldb.SBError
    ) -> lldb.SBData:
        """Read ``size`` bytes at ``addr`` from the driving process.

        On failure ``error`` is set and an empty SBData is returned.
        """
        data = lldb.SBData()
        if self.driving_process is None:
            error.SetErrorString(
                f"{self.__class__.__name__}.read_memory_at_address: "
                "Invalid driving process."
            )
            return data

        bytes_read = self.driving_process.ReadMemory(addr, size, error)

        if error.Fail():
            return data

        data.SetDataWithOwnership(
            error,
            bytes_read,
            self.driving_target.GetByteOrder(),
            self.driving_target.GetAddressByteSize(),
        )

        return data

    def write_memory_at_address(
        self, addr: int, data: lldb.SBData, error: lldb.SBError
    ) -> int:
        """Write ``data`` at ``addr`` in the driving process.

        Returns the number of bytes written (0 when there is no driving process).
        """
        if self.driving_process is None:
            error.SetErrorString(
                f"{self.__class__.__name__}.write_memory_at_address: "
                "Invalid driving process."
            )
            return 0
        return self.driving_process.WriteMemory(
            addr, bytearray(data.uint8.all()), error
        )

    def get_process_id(self) -> int:
        return 42

    def is_alive(self) -> bool:
        return True

    def get_scripted_thread_plugin(self) -> str:
        """Return the fully qualified name of the thread class used by lldb."""
        return f"{PassthruScriptedThread.__module__}.{PassthruScriptedThread.__name__}"
108
109
class MultiplexedScriptedProcess(PassthruScriptedProcess):
    """Child scripted process that shows a parity-filtered view of the multiplexer.

    Its process ID is ``420 + parity``. Once ``multiplexer`` is set (see
    ``multiplex``), thread listing, resuming and breakpoint creation are
    delegated to that multiplexer. The multiplexer is asked for the threads
    matching this process' ID parity.
    """

    def __init__(self, exe_ctx: lldb.SBExecutionContext, args: lldb.SBStructuredData):
        super().__init__(exe_ctx, args)
        self.multiplexer = None
        # Defaults so get_process_id()/resume() work even if the driving
        # process is missing or launch() was never called.
        self.parity = 0
        self.first_launch = False
        if isinstance(self.driving_process, lldb.SBProcess) and self.driving_process:
            parity = args.GetValueForKey("parity")
            # TODO: Use the walrus operator (:=) once Python 3.8 is required.
            val = extract_value_from_structured_data(parity, 0)
            if val is not None:
                self.parity = val

            # Turn PassthruScriptedThread into MultiplexedScriptedThread
            for thread in self.threads.values():
                thread.__class__ = MultiplexedScriptedThread

    def get_process_id(self) -> int:
        return self.parity + 420

    def launch(self, should_stop: bool = True) -> lldb.SBError:
        """Do not launch anything; arm the first resume to use the base class."""
        self.first_launch = True
        return lldb.SBError()

    def resume(self, should_stop: bool = True) -> lldb.SBError:
        """Resume via the base class the first time, then via the multiplexer."""
        if self.first_launch:
            self.first_launch = False
            # Forward the caller's should_stop instead of silently using the
            # base class default.
            return super().resume(should_stop)
        if not self.multiplexer:
            return lldb.SBError("Multiplexer is not set.")
        return self.multiplexer.resume(should_stop)

    def get_threads_info(self) -> Dict[int, Any]:
        """Return this process' threads, keyed by thread ID.

        Without a multiplexer, the base class' threads are returned. Otherwise
        the multiplexer's threads whose ID parity matches this process are
        re-created as MultiplexedScriptedThread instances owned by this process
        and backed by the same driving thread. The previous code called the
        thread constructor with a single argument, which always raised TypeError.
        """
        if not self.multiplexer:
            return super().get_threads_info()
        filtered_threads = self.multiplexer.get_threads_info(pid=self.get_process_id())

        debugger = self.target.GetDebugger()
        multiplexed_threads = {}
        for tid, passthru_thread in filtered_threads.items():
            structured_data = lldb.SBStructuredData()
            structured_data.SetFromJSON(
                json.dumps(
                    {
                        "driving_target_idx": debugger.GetIndexOfTarget(
                            passthru_thread.driving_target
                        ),
                        "thread_idx": passthru_thread.idx,
                    }
                )
            )
            multiplexed_threads[tid] = MultiplexedScriptedThread(self, structured_data)
        return multiplexed_threads

    def create_breakpoint(self, addr, error, pid=None):
        """Ask the multiplexer to create a breakpoint tagged with this process' ID.

        ``pid`` is accepted for interface compatibility and ignored.
        Returns False (and sets ``error``) when no multiplexer is attached.
        """
        if not self.multiplexer:
            error.SetErrorString("Multiplexer is not set.")
            return False
        return self.multiplexer.create_breakpoint(addr, error, self.get_process_id())

    def get_scripted_thread_plugin(self) -> str:
        """Return the fully qualified name of the thread class used by lldb."""
        return f"{MultiplexedScriptedThread.__module__}.{MultiplexedScriptedThread.__name__}"
162
163
class PassthruScriptedThread(ScriptedThread):
    """Scripted thread that mirrors one thread of a driving process.

    ``args`` carries ``driving_target_idx`` (index of the driving target in
    the debugger) and ``thread_idx`` (index ID of the driving thread). The
    thread ID, stop reason and general-purpose registers are read from that
    driving thread.
    """

    def __init__(self, process, args):
        super().__init__(process, args)
        target_idx_data = args.GetValueForKey("driving_target_idx")
        thread_idx_data = args.GetValueForKey("thread_idx")

        # TODO: Use the walrus operator (:=) once Python 3.8 is required.
        thread_index = extract_value_from_structured_data(thread_idx_data, 0)
        if thread_index is not None:
            self.idx = thread_index

        self.driving_target = self.driving_process = self.driving_thread = None

        target_index = extract_value_from_structured_data(target_idx_data, 42)
        if target_index is not None:
            debugger = self.target.GetDebugger()
            self.driving_target = debugger.GetTargetAtIndex(target_index)
            self.driving_process = self.driving_target.GetProcess()
            self.driving_thread = self.driving_process.GetThreadByIndexID(self.idx)

        # `id` is only assigned when the driving thread could be resolved.
        if self.driving_thread:
            self.id = self.driving_thread.GetThreadID()

    def get_thread_id(self) -> int:
        return self.id

    def get_name(self) -> str:
        return f"{PassthruScriptedThread.__name__}.thread-{self.idx}"

    def get_stop_reason(self) -> Dict[str, Any]:
        """Translate the driving thread's stop reason for this scripted thread.

        Invalid unless the driving thread is valid and has this thread's ID.
        A real stop is reported as an exception (with the driving thread's
        description) on arm64, as SIGTRAP on x86_64, and as the driving
        thread's own stop reason on any other architecture.
        """
        reason_type = lldb.eStopReasonInvalid
        reason_data = {}

        thread = self.driving_thread
        is_mirrored = (
            thread
            and thread.IsValid()
            and self.get_thread_id() == thread.GetThreadID()
        )
        if is_mirrored:
            reason_type = lldb.eStopReasonNone
            if thread.GetStopReason() != lldb.eStopReasonNone:
                if "arm64" in self.scripted_process.arch:
                    reason_type = lldb.eStopReasonException
                    reason_data["desc"] = thread.GetStopDescription(100)
                elif self.scripted_process.arch == "x86_64":
                    reason_type = lldb.eStopReasonSignal
                    reason_data["signal"] = signal.SIGTRAP
                else:
                    reason_type = thread.GetStopReason()

        return {"type": reason_type, "data": reason_data}

    def get_register_context(self) -> str:
        """Pack the driving thread's general-purpose registers as 64-bit words.

        Reads frame 0 of the driving thread, updates ``self.register_ctx`` with
        each register parsed as hex, and packs every value in the context with
        ``struct`` format ``Q``. Returns None when there is no frame or no
        register set whose name contains "general purpose".
        """
        thread = self.driving_thread
        if not thread or thread.GetNumFrames() == 0:
            return None
        top_frame = thread.GetFrameAtIndex(0)

        gpr_set = next(
            (
                register_set
                for register_set in top_frame.registers
                if "general purpose" in register_set.name.lower()
            ),
            None,
        )
        if not gpr_set:
            return None

        for register in gpr_set:
            self.register_ctx[register.name] = int(register.value, base=16)

        values = list(self.register_ctx.values())
        return struct.pack(f"{len(values)}Q", *values)
240
241
class MultiplexedScriptedThread(PassthruScriptedThread):
    """Passthrough thread whose name is prefixed with its process' parity."""

    def get_name(self) -> str:
        # The owning scripted process exposes `parity`; even -> "Even", odd -> "Odd".
        label = "Even" if self.scripted_process.parity % 2 == 0 else "Odd"
        return f"{label}{MultiplexedScriptedThread.__name__}.thread-{self.idx}"
246
247
class MultiplexerScriptedProcess(PassthruScriptedProcess):
    """Scripted process that owns the driving process and fans it out.

    Unlike its base class, it launches the driving process itself (see
    ``launch``) with a dedicated listener. A daemon thread waits for the
    driving process to stop. On each stop it rebuilds ``self.threads`` from
    the driving threads and forces a running->stopped transition on this
    process and on every process in ``multiplexed_processes``. That mapping
    is keyed by process ID and filled by ``multiplex``.
    """

    listener = None
    multiplexed_processes = None

    def wait_for_driving_process_to_stop(self):
        """Listener loop run on the daemon thread started in ``__init__``.

        Never returns. On every public "stopped" state event from the driving
        process, it refreshes the mirrored threads and process states.
        """

        def handle_process_state_event():
            # Update multiplexer process
            log("Updating interactive scripted process threads")
            dbg = self.driving_target.GetDebugger()
            # The driving target index is the same for every mirrored thread.
            driving_target_idx = dbg.GetIndexOfTarget(self.driving_target)
            log("Clearing interactive scripted process threads")
            self.threads.clear()
            for driving_thread in self.driving_process:
                log(f"{len(self.threads)} New thread {hex(driving_thread.id)}")
                structured_data = lldb.SBStructuredData()
                structured_data.SetFromJSON(
                    json.dumps(
                        {
                            "driving_target_idx": driving_target_idx,
                            "thread_idx": driving_thread.GetIndexID(),
                        }
                    )
                )

                self.threads[driving_thread.GetThreadID()] = PassthruScriptedThread(
                    self, structured_data
                )

            # Force a running -> stopped transition on the multiplexer and on
            # every multiplexed child process.
            mux_process = self.target.GetProcess()
            mux_process.ForceScriptedState(lldb.eStateRunning)
            mux_process.ForceScriptedState(lldb.eStateStopped)

            for child_process in self.multiplexed_processes.values():
                child_process.ForceScriptedState(lldb.eStateRunning)
                child_process.ForceScriptedState(lldb.eStateStopped)

        event = lldb.SBEvent()
        while True:
            # Poll with a 1 second timeout.
            if not self.listener.WaitForEvent(1, event):
                continue
            if not event.GetType() & lldb.SBProcess.eBroadcastBitStateChanged:
                continue
            state = lldb.SBProcess.GetStateFromEvent(event)
            log(f"Received public process state event: {state}")
            if state == lldb.eStateStopped:
                handle_process_state_event()

    def __init__(self, exe_ctx: lldb.SBExecutionContext, args: lldb.SBStructuredData):
        super().__init__(exe_ctx, args, launched_driving_process=False)
        # launch() arms this. Initialising it here means resume() before
        # launch() reports an error instead of raising AttributeError.
        self.first_resume = False
        if isinstance(self.driving_target, lldb.SBTarget) and self.driving_target:
            self.listener = lldb.SBListener(
                "lldb.listener.multiplexer-scripted-process"
            )
            self.multiplexed_processes = {}

            # Copy breakpoints from real target to passthrough
            with tempfile.NamedTemporaryFile() as tf:
                bkpt_file = lldb.SBFileSpec(tf.name)
                error = self.driving_target.BreakpointsWriteToFile(bkpt_file)
                if error.Fail():
                    log(
                        "Failed to save breakpoints from driving target (%s)"
                        % error.GetCString()
                    )
                bkpts_list = lldb.SBBreakpointList(self.target)
                error = self.target.BreakpointsCreateFromFile(bkpt_file, bkpts_list)
                if error.Fail():
                    log(
                        "Failed to create breakpoints from driving target "
                        "(bkpt file: %s)" % tf.name
                    )

            # Copy breakpoints from passthrough back to the real target,
            # re-created by address.
            if error.Success():
                self.driving_target.DeleteAllBreakpoints()
                for bkpt in self.target.breakpoints:
                    if not bkpt.IsValid():
                        continue
                    for bl in bkpt:
                        real_bkpt = self.driving_target.BreakpointCreateBySBAddress(
                            bl.GetAddress()
                        )
                        if not real_bkpt.IsValid():
                            log(
                                "Failed to set breakpoint at address %s in "
                                "driving target" % hex(bl.GetLoadAddress())
                            )

            self.listener_thread = Thread(
                target=self.wait_for_driving_process_to_stop, daemon=True
            )
            self.listener_thread.start()

    def launch(self, should_stop: bool = True) -> lldb.SBError:
        """Launch the driving process with this object's listener.

        Fails if there is no driving target or a driving process already
        exists. On success, records the driving modules as loaded images and
        arms ``first_resume``.
        """
        if not self.driving_target:
            return lldb.SBError(
                f"{self.__class__.__name__}.launch: Invalid driving target."
            )

        if self.driving_process:
            return lldb.SBError(
                f"{self.__class__.__name__}.launch: Driving process already exists."
            )

        error = lldb.SBError()
        launch_info = lldb.SBLaunchInfo(None)
        launch_info.SetListener(self.listener)
        driving_process = self.driving_target.Launch(launch_info, error)

        if not driving_process or error.Fail():
            return error

        self.driving_process = driving_process

        for module in self.driving_target.modules:
            path = module.file.fullpath
            load_addr = module.GetObjectFileHeaderAddress().GetLoadAddress(
                self.driving_target
            )
            self.loaded_images.append({"path": path, "load_addr": load_addr})

        self.first_resume = True
        return error

    def resume(self, should_stop: bool = True) -> lldb.SBError:
        """Continue the driving process (no-op on the very first resume)."""
        if self.first_resume:
            # When we resume the multiplexer process for the first time,
            # we shouldn't do anything because lldb's execution machinery
            # will resume the driving process by itself.

            # Also, no need to update the multiplexer scripted process state
            # here because it is listening for the real process stop events.
            # Once it receives the stop event from the driving process,
            # `wait_for_driving_process_to_stop` will update the multiplexer
            # state for us.

            self.first_resume = False
            return lldb.SBError()

        if not self.driving_process:
            return lldb.SBError(
                f"{self.__class__.__name__}.resume: Invalid driving process."
            )

        return self.driving_process.Continue()

    def get_threads_info(self, pid: int = None) -> Dict[int, Any]:
        """Return all threads, or only those whose ID parity matches ``pid``."""
        if not pid:
            return super().get_threads_info()
        parity = pid % 2
        return dict(filter(lambda pair: pair[0] % 2 == parity, self.threads.items()))

    def create_breakpoint(self, addr, error, pid=None):
        """Create a named breakpoint at load address ``addr``.

        Without ``pid`` the breakpoint is created in the driving target and
        named "multiplexer_scripted_process". With ``pid`` (a call forwarded by
        a multiplexed process) it is first created in this multiplexer's own
        target, then in the driving target, both named
        "multiplexed_scripted_process_<pid>". Returns True on success;
        otherwise returns False and sets ``error``.
        """
        if not self.driving_target:
            error.SetErrorString("%s has no driving target." % self.__class__.__name__)
            return False

        def create_breakpoint_with_name(target, load_addr, name, error):
            bkpt_addr = lldb.SBAddress(load_addr, target)
            if not bkpt_addr.IsValid():
                error.SetErrorString("Invalid breakpoint address %s" % hex(load_addr))
                return False
            bkpt = target.BreakpointCreateBySBAddress(bkpt_addr)
            if not bkpt.IsValid():
                # SBAddress.GetLoadAddress needs a target, so report the
                # requested load address directly.
                error.SetErrorString(
                    "Failed to create breakpoint at address %s" % hex(load_addr)
                )
                return False
            # Propagate naming failures to the caller's error object instead of
            # shadowing it with a local.
            name_error = bkpt.AddNameWithErrorHandling(name)
            if name_error.Fail():
                error.SetErrorString(name_error.GetCString())
                return False
            return True

        name = (
            "multiplexer_scripted_process"
            if pid is None
            else f"multiplexed_scripted_process_{pid}"
        )

        if pid is not None:
            # This means that this method has been called from one of the
            # multiplexed scripted process. That also means that the multiplexer
            # target doesn't have this breakpoint created.
            mux_error = lldb.SBError()
            if not create_breakpoint_with_name(self.target, addr, name, mux_error):
                error.SetErrorString(
                    "Failed to create breakpoint in multiplexer target: %s"
                    % mux_error.GetCString()
                )
                return False
        return create_breakpoint_with_name(self.driving_target, addr, name, error)
446
447
def multiplex(mux_process, muxed_process):
    """Link a multiplexed child process to its multiplexer.

    The child's scripted implementation gets ``multiplexer`` pointing at the
    multiplexer's implementation. The multiplexer records the child process
    under its process ID in ``multiplexed_processes``.
    """
    mux_impl = mux_process.GetScriptedImplementation()
    muxed_impl = muxed_process.GetScriptedImplementation()
    muxed_impl.multiplexer = mux_impl
    mux_impl.multiplexed_processes[muxed_process.GetProcessID()] = muxed_process
455
456
def launch_scripted_process(target, class_name, dictionary):
    """Launch ``class_name`` as a ScriptedProcess in ``target``.

    ``dictionary`` is serialized to JSON and handed to the scripted process
    as its argument dictionary. Returns whatever ``target.Launch`` returns.
    The launch error is not reported back to the caller.
    """
    launch_info = lldb.SBLaunchInfo(None)
    launch_info.SetProcessPluginName("ScriptedProcess")
    launch_info.SetScriptedProcessClassName(class_name)

    process_args = lldb.SBStructuredData()
    process_args.SetFromJSON(json.dumps(dictionary))
    launch_info.SetScriptedProcessDictionary(process_args)

    launch_error = lldb.SBError()
    return target.Launch(launch_info, launch_error)
468
469
def duplicate_target(driving_target):
    """Create a new target from the same executable and triple as ``driving_target``.

    The new target is created in the same debugger. Returns whatever
    ``CreateTargetWithFileAndTargetTriple`` returns.
    """
    # The previous version built an SBError that was never used; it is dropped.
    exe = driving_target.executable.fullpath
    triple = driving_target.triple
    debugger = driving_target.GetDebugger()
    return debugger.CreateTargetWithFileAndTargetTriple(exe, triple)
476
477
def extract_value_from_structured_data(data, default_val):
    """Read an integer out of an SBStructuredData value.

    Integer values are returned via ``GetIntegerValue(default_val)``. String
    values are parsed with ``int()``, which raises ValueError if the string is
    not numeric. Missing, invalid or other-typed data yields ``default_val``.
    """
    if not data or not data.IsValid():
        return default_val

    data_type = data.GetType()
    if data_type == lldb.eStructuredDataTypeInteger:
        return data.GetIntegerValue(default_val)
    if data_type == lldb.eStructuredDataTypeString:
        return int(data.GetStringValue(100))
    return default_val
485
486
def create_mux_process(debugger, command, exe_ctx, result, dict):
    """Implementation of the ``create_mux`` command.

    Switches the debugger to async mode and uses the selected target as the
    driving target. It duplicates that target and launches a
    MultiplexerScriptedProcess in the copy. Failures are reported with
    ``result.SetError``.
    """
    if debugger.GetNumTargets() <= 0:
        return result.SetError(
            "Interactive scripted processes requires one non scripted process."
        )

    debugger.SetAsync(True)

    driving_target = debugger.GetSelectedTarget()
    if not driving_target:
        return result.SetError("Driving target is invalid")

    # The multiplexer scripted process runs in its own copy of the driving target.
    mux_target = duplicate_target(driving_target)
    if not mux_target:
        return result.SetError(
            "Couldn't duplicate driving target to launch multiplexer scripted process"
        )

    mux_process = launch_scripted_process(
        mux_target,
        f"{__name__}.{MultiplexerScriptedProcess.__name__}",
        {"driving_target_idx": debugger.GetIndexOfTarget(driving_target)},
    )
    if not mux_process:
        return result.SetError("Couldn't launch multiplexer scripted process")
511
512
def create_child_processes(debugger, command, exe_ctx, result, dict):
    """Implementation of the ``create_sub`` command.

    Expects the selected target to host the multiplexer scripted process.
    Launches two MultiplexedScriptedProcess children, even (parity 0) and odd
    (parity 1), each in its own copy of the driving target. Each child is
    linked to the multiplexer with ``multiplex``. Failures are reported with
    ``result.SetError``.
    """
    if debugger.GetNumTargets() < 2:
        return result.SetError("Scripted Multiplexer process not setup")

    debugger.SetAsync(True)

    # The selected target is the one running the multiplexer scripted process.
    mux_target = debugger.GetSelectedTarget()
    if not mux_target:
        return result.SetError("Couldn't get multiplexer scripted process target")
    mux_process = mux_target.GetProcess()
    if not mux_process:
        return result.SetError("Couldn't get multiplexer scripted process")

    driving_target = mux_process.GetScriptedImplementation().driving_target
    if not driving_target:
        return result.SetError("Driving target is invalid")

    for parity, label in ((0, "even"), (1, "odd")):
        # Each multiplexed process gets its own copy of the driving target.
        child_target = duplicate_target(driving_target)
        if not child_target:
            return result.SetError(
                f"Couldn't duplicate driving target to launch multiplexed {label} scripted process"
            )

        child_process = launch_scripted_process(
            child_target,
            f"{__name__}.{MultiplexedScriptedProcess.__name__}",
            {
                "driving_target_idx": debugger.GetIndexOfTarget(mux_target),
                "parity": parity,
            },
        )
        if not child_process:
            return result.SetError(
                f"Couldn't launch multiplexed {label} scripted process"
            )
        multiplex(mux_process, child_process)
558
559
def log(message):
    """Debug logging hook; currently discards every message."""
    # FIXME: Route these messages to an lldb logging channel instead of
    # dropping them. Flip `should_log` to print them to stdout meanwhile.
    should_log = False
    if not should_log:
        return
    print(message)
566
567
def __lldb_init_module(dbg, dict):
    """Register the ``create_mux`` and ``create_sub`` commands on import."""
    command_table = (
        ("create_mux_process", "create_mux"),
        ("create_child_processes", "create_sub"),
    )
    for function_name, command_name in command_table:
        dbg.HandleCommand(
            f"command script add -o -f interactive_scripted_process.{function_name} {command_name}"
        )
575