xref: /llvm-project/cross-project-tests/debuginfo-tests/dexter/dex/heuristic/Heuristic.py (revision f98ee40f4b5d7474fc67e82824bf6abbaedb7b1c)
1# DExTer : Debugging Experience Tester
2# ~~~~~~   ~         ~~         ~   ~~
3#
4# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5# See https://llvm.org/LICENSE.txt for license information.
6# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7"""Calculate a 'score' based on some dextIR.
8Assign penalties based on different commands to decrease the score.
91.000 would be a perfect score.
100.000 is the worst theoretical score possible.
11"""
12
13from collections import defaultdict, namedtuple, Counter
14import difflib
15import os
16from itertools import groupby
17from dex.command.StepValueInfo import StepValueInfo
18from dex.command.commands.DexExpectWatchBase import format_address
19
20
# One entry of Heuristic.penalties. `pen_dict` maps a category description
# (a string) to a list of PenaltyInstance; `max_penalty` is this entry's
# contribution to the maximum possible total penalty.
PenaltyCommand = namedtuple("PenaltyCommand", ["pen_dict", "max_penalty"])
# 'meta' field used in different ways by different things
# (for example a watch result, a count, or a preformatted string);
# `the_penalty` is the amount this instance adds to the total penalty.
PenaltyInstance = namedtuple("PenaltyInstance", ["meta", "the_penalty"])
24
25
def add_heuristic_tool_arguments(parser):
    """Register the heuristic's penalty options on *parser*.

    Each option accepts an integer, is displayed with the ``<int>`` metavar
    and falls back to the default listed in the table below when it is not
    given on the command line.
    """
    # (option string, default, help text), registered in this order.
    penalty_options = (
        (
            "--penalty-variable-optimized",
            3,
            "set the penalty multiplier for each occurrence of a variable"
            " that was optimized away",
        ),
        (
            "--penalty-misordered-values",
            3,
            "set the penalty multiplier for each occurrence of a misordered"
            " value.",
        ),
        (
            "--penalty-irretrievable",
            4,
            "set the penalty multiplier for each occurrence of a variable"
            " that couldn't be retrieved",
        ),
        (
            "--penalty-not-evaluatable",
            5,
            "set the penalty multiplier for each occurrence of a variable"
            " that couldn't be evaluated",
        ),
        (
            "--penalty-missing-values",
            6,
            "set the penalty multiplier for each missing value",
        ),
        (
            "--penalty-incorrect-values",
            7,
            "set the penalty multiplier for each occurrence of an unexpected"
            " value.",
        ),
        (
            "--penalty-unreachable",
            4,  # XXX XXX XXX selected by random
            "set the penalty for each line stepped onto that should have"
            " been unreachable.",
        ),
        (
            "--penalty-misordered-steps",
            2,  # XXX XXX XXX selected by random
            "set the penalty for differences in the order of steps the"
            " program was expected to observe.",
        ),
        (
            "--penalty-missing-step",
            4,  # XXX XXX XXX selected by random
            "set the penalty for the program skipping over a step.",
        ),
        (
            "--penalty-incorrect-program-state",
            4,  # XXX XXX XXX selected by random
            "set the penalty for the program never entering an expected"
            " state or entering an unexpected state.",
        ),
    )

    for option, default_value, help_text in penalty_options:
        parser.add_argument(
            option,
            type=int,
            default=default_value,
            help=help_text,
            metavar="<int>",
        )
107
108
class PenaltyLineRanges:
    """Collect step indices into contiguous ranges and sum their penalties.

    ``ranges`` is a list of inclusive ``(first, last)`` tuples in the order
    the steps were added; ``penalty`` is the running total of the penalties
    passed to the constructor and to ``add_step``.
    """

    def __init__(self, first_step, penalty):
        """Start with the single-step range ``first_step`` and its penalty."""
        self.ranges = [(first_step, first_step)]
        self.penalty = penalty

    def add_step(self, next_step, penalty):
        """Record ``next_step`` and add ``penalty`` to the running total.

        The most recent range is extended when ``next_step`` immediately
        follows its last step; otherwise a new single-step range is opened.
        """
        range_start, range_end = self.ranges[-1]
        if next_step == range_end + 1:
            self.ranges[-1] = (range_start, next_step)
        else:
            self.ranges.append((next_step, next_step))
        self.penalty += penalty

    @staticmethod
    def _format_range(step_range):
        """Render one range as ``"N"`` (single step) or ``"N-M"``."""
        first, last = step_range
        return str(first) if first == last else f"{first}-{last}"

    def __str__(self):
        """Return ``step N`` or ``steps [..]``, plus the penalty if non-zero."""
        only_step = self.ranges[0][0]
        # Use the singular form only when every recorded range is that one
        # step. Comparing just the first start with the last end would
        # misreport e.g. [(3, 3), (1, 3)] as "step 3".
        if all(r == (only_step, only_step) for r in self.ranges):
            text = f"step {only_step}"
        else:
            step_list = ", ".join(self._format_range(r) for r in self.ranges)
            text = f"steps [{step_list}]"
        if self.penalty:
            text += " <r>[-{}]</>".format(self.penalty)
        return text
133
134
class Heuristic(object):
    """Turn the Dex* commands attached to a set of steps into a score.

    Construction evaluates every supported command kind found in
    ``steps.commands`` and stores the outcome in ``self.penalties``: a dict
    mapping a display name to a ``PenaltyCommand``. The ``penalty``,
    ``max_penalty``, ``score``, ``summary_string`` and ``verbose_output``
    properties are all derived from that dict.
    """

    def __init__(self, context, steps):
        """Evaluate all commands in *steps* and record their penalties.

        Args:
            context: object whose ``options`` attribute carries the
                ``penalty_*`` values registered by
                ``add_heuristic_tool_arguments``.
            steps: object exposing ``commands`` (mapping of command-kind name
                to a list of command objects) and ``steps`` (the recorded
                steps, in order).
        """
        self.context = context
        self.penalties = {}
        self.address_resolutions = {}

        # NOTE(review): penalty_misordered_values and
        # penalty_incorrect_program_state are not part of this maximum;
        # confirm whether that is intentional.
        worst_penalty = max(
            [
                self.penalty_variable_optimized,
                self.penalty_irretrievable,
                self.penalty_not_evaluatable,
                self.penalty_incorrect_values,
                self.penalty_missing_values,
                self.penalty_unreachable,
                self.penalty_missing_step,
                self.penalty_misordered_steps,
            ]
        )

        # Before evaluating scoring commands, evaluate address values.
        self._resolve_declared_addresses(steps)
        self._score_expect_watches(
            steps, "DexExpectWatchType", " ExpectType", worst_penalty, False
        )
        self._score_expect_watches(
            steps, "DexExpectWatchValue", " ExpectValue", worst_penalty, True
        )
        self._score_program_states(steps)
        self._score_step_kinds(steps)
        self._score_unreachable(steps)
        self._score_step_order(steps)

    def _resolve_declared_addresses(self, steps):
        """Evaluate DexDeclareAddress commands into address_resolutions."""
        try:
            for command in steps.commands["DexDeclareAddress"]:
                command.address_resolutions = self.address_resolutions
                command.eval(steps)
        except KeyError:
            # No commands of this kind were given.
            pass

    def _score_expect_watches(
        self, steps, kind, suffix, worst_penalty, resolve_addresses
    ):
        """Score every watch-expectation command of the given *kind*.

        Each command is stored in ``self.penalties`` under its computed name
        plus *suffix*. When *resolve_addresses* is true the command is handed
        the shared ``address_resolutions`` dict before being evaluated.
        """
        try:
            for command in steps.commands[kind]:
                if resolve_addresses:
                    command.address_resolutions = self.address_resolutions
                command.eval(steps)
                # At most three expected values count towards the cap.
                maximum_possible_penalty = min(3, len(command.values)) * worst_penalty
                name, p = self._calculate_expect_watch_penalties(
                    command, maximum_possible_penalty
                )
                self.penalties[name + suffix] = PenaltyCommand(
                    p, maximum_possible_penalty
                )
        except KeyError:
            # No commands of this kind were given.
            pass

    def _score_program_states(self, steps):
        """Score DexExpectProgramState commands as one penalties entry."""
        try:
            penalties = defaultdict(list)
            maximum_possible_penalty_all = 0
            for expect_state in steps.commands["DexExpectProgramState"]:
                success = expect_state.eval(steps)
                p = 0 if success else self.penalty_incorrect_program_state

                # A negative `times` is reported as "at least once".
                meta = "expected {}: {}".format(
                    "{} times".format(expect_state.times)
                    if expect_state.times >= 0
                    else "at least once",
                    expect_state.program_state_text,
                )

                if success:
                    meta = "<g>{}</>".format(meta)

                maximum_possible_penalty_all += self.penalty_incorrect_program_state
                penalties[meta] = [
                    PenaltyInstance("{} times".format(len(expect_state.encounters)), p)
                ]
            self.penalties["expected program states"] = PenaltyCommand(
                penalties, maximum_possible_penalty_all
            )
        except KeyError:
            # No commands of this kind were given.
            pass

    def _score_step_kinds(self, steps):
        """Score DexExpectStepKind commands against the observed step kinds."""
        # Get the total number of each step kind.
        step_kind_counts = Counter(step.step_kind for step in steps.steps)

        penalties = defaultdict(list)
        maximum_possible_penalty_all = 0
        try:
            for command in steps.commands["DexExpectStepKind"]:
                command.eval()
                # Cap the penalty at 2 * expected count or else 1
                maximum_possible_penalty = max(command.count * 2, 1)
                p = abs(command.count - step_kind_counts[command.name])
                actual_penalty = min(p, maximum_possible_penalty)
                key = (
                    "{}".format(command.name)
                    if actual_penalty
                    else "<g>{}</>".format(command.name)
                )
                penalties[key] = [PenaltyInstance(p, actual_penalty)]
                maximum_possible_penalty_all += maximum_possible_penalty
            self.penalties["step kind differences"] = PenaltyCommand(
                penalties, maximum_possible_penalty_all
            )
        except KeyError:
            # No commands of this kind were given.
            pass

    def _score_unreachable(self, steps):
        """Penalise every step that carries a DexUnreachable watch."""
        if "DexUnreachable" not in steps.commands:
            return
        cmds = steps.commands["DexUnreachable"]

        # Find steps with unreachable in them
        ureachs = [s for s in steps.steps if "DexUnreachable" in s.watches]

        # There's no need to match up cmds with the actual watches
        upen = self.penalty_unreachable

        if upen * len(ureachs) != 0:
            # Steps on the same line share one message, hence one entry.
            d = {
                "line {} reached".format(x.current_location.lineno): [
                    PenaltyInstance(upen, upen)
                ]
                for x in ureachs
            }
        else:
            d = {"<g>No unreachable lines seen</>": [PenaltyInstance(0, 0)]}

        self.penalties["unreachable lines"] = PenaltyCommand(d, len(cmds) * upen)

    @staticmethod
    def _diff_step_order(expected_lines, seen_lines):
        """Return the difflib.Differ delta of two line-number sequences.

        Entries start with "  ", "- " or "+ ". Differ also emits "? " hint
        lines when a removed and an added entry look alike (for example
        "12345" vs "12346"); those carry no line number and are dropped.
        """
        delta = difflib.Differ().compare(
            [str(x) for x in expected_lines], [str(x) for x in seen_lines]
        )
        return [entry for entry in delta if not entry.startswith("?")]

    def _score_step_order(self, steps):
        """Score DexExpectStepOrder: unseen, skipped and misordered lines."""
        if "DexExpectStepOrder" not in steps.commands:
            return
        cmds = steps.commands["DexExpectStepOrder"]

        # Form a list of which line/cmd we _should_ have seen
        cmd_num_lst = [(x, c.get_line()) for c in cmds for x in c.sequence]
        # Order them by the sequence number
        cmd_num_lst.sort(key=lambda t: t[0])
        # Strip out sequence key
        cmd_num_lst = [line for _, line in cmd_num_lst]

        # Now do the same, but for the actually observed lines/cmds.
        # We rely on the steps remaining in order here
        order_list = [
            int(s.watches["DexExpectStepOrder"].expression)
            for s in steps.steps
            if "DexExpectStepOrder" in s.watches
        ]

        # First off, check to see whether or not there are missing items
        expected = Counter(cmd_num_lst)
        seen = Counter(order_list)

        unseen_line_dict = dict()
        skipped_line_dict = dict()

        mispen = self.penalty_missing_step
        for k, v in expected.items():
            if k not in seen:
                msg = "Line {} not seen".format(k)
                unseen_line_dict[msg] = [PenaltyInstance(mispen, mispen)]
            elif v > seen[k]:
                msg = "Line {} skipped at least once".format(k)
                skipped_line_dict[msg] = [PenaltyInstance(mispen, mispen)]
            # Don't penalise unexpected extra sightings of a line for now

        if not unseen_line_dict:
            unseen_line_dict["<g>All lines were seen</>"] = [PenaltyInstance(0, 0)]

        if not skipped_line_dict:
            skipped_line_dict["<g>No lines were skipped</>"] = [PenaltyInstance(0, 0)]

        self.penalties["Unseen lines"] = PenaltyCommand(
            unseen_line_dict, len(expected) * mispen
        )
        self.penalties["Skipped lines"] = PenaltyCommand(
            skipped_line_dict, len(expected) * mispen
        )

        ordpen = self.penalty_misordered_steps
        lst = self._diff_step_order(cmd_num_lst, order_list)

        # Diffs are hard to interpret; there are many algorithms for
        # condensing them. Ignore all that, and just print out the changed
        # sequences, it's up to the user to interpret what's going on.

        def filt_lines(s, seg, e, key):
            # Line numbers of the entries in `seg` tagged `key`, bracketed
            # by the surrounding unchanged steps.
            return [s] + [int(x[2:]) for x in seg if x[0] == key] + [e]

        diff_msgs = dict()

        def reportdiff(start_idx, segment, end_idx):
            msg = "Order mismatch, expected linenos {}, saw {}"
            expected_linenos = filt_lines(start_idx, segment, end_idx, "-")
            seen_linenos = filt_lines(start_idx, segment, end_idx, "+")
            msg = msg.format(expected_linenos, seen_linenos)
            diff_msgs[msg] = [PenaltyInstance(ordpen, ordpen)]

        # Group by changed segments.
        start_expt_step = 0
        to_print_lst = []
        for unchanged, group in groupby(lst, lambda x: x[0] == " "):
            entries = list(group)
            if unchanged:  # Whitespace group
                end_expt_step = int(entries[0][2:])
                if to_print_lst:
                    reportdiff(start_expt_step, to_print_lst, end_expt_step)
                start_expt_step = int(entries[-1][2:])
                to_print_lst = []
            else:  # Diff group, save for printing
                to_print_lst = entries

        # If there was a dangling different step, print that too.
        if to_print_lst:
            reportdiff(start_expt_step, to_print_lst, "[End]")

        if not diff_msgs:
            diff_msgs["<g>No lines misordered</>"] = [PenaltyInstance(0, 0)]
        self.penalties["Misordered lines"] = PenaltyCommand(
            diff_msgs, len(cmd_num_lst) * ordpen
        )

    def _calculate_expect_watch_penalties(self, c, maximum_possible_penalty):
        """Compute the penalties for one evaluated watch-expectation command.

        Args:
            c: the evaluated command; its path, line range, expression and
                the various ``*_watches`` / ``*_values`` lists are read.
            maximum_possible_penalty: budget for this command. Individual
                penalties are clamped so their sum never exceeds it.

        Returns:
            ``(name, penalties)`` where ``name`` is
            ``"<file>:<line or first-last> [<expression>]"`` and
            ``penalties`` maps a category description to a list of
            ``PenaltyInstance``.
        """
        penalties = defaultdict(list)

        if c.line_range[0] == c.line_range[-1]:
            line_range = str(c.line_range[0])
        else:
            line_range = "{}-{}".format(c.line_range[0], c.line_range[-1])

        name = "{}:{} [{}]".format(os.path.basename(c.path), line_range, c.expression)

        num_actual_watches = len(c.expected_watches) + len(c.unexpected_watches)

        penalty_available = maximum_possible_penalty

        # Only penalize for missing values if we have actually seen a watch
        # that's returned us an actual value at some point, or if we've not
        # encountered the value at all.
        if num_actual_watches or c.times_encountered == 0:
            for v in c.missing_values:
                current_penalty = min(penalty_available, self.penalty_missing_values)
                penalty_available -= current_penalty
                penalties["missing values"].append(PenaltyInstance(v, current_penalty))

        for v in c.encountered_values:
            penalties["<g>expected encountered watches</>"].append(
                PenaltyInstance(v, 0)
            )

        penalty_descriptions = [
            (self.penalty_not_evaluatable, c.invalid_watches, "could not evaluate"),
            (
                self.penalty_variable_optimized,
                c.optimized_out_watches,
                "result optimized away",
            ),
            (self.penalty_misordered_values, c.misordered_watches, "misordered result"),
            (
                self.penalty_irretrievable,
                c.irretrievable_watches,
                "result could not be retrieved",
            ),
            (self.penalty_incorrect_values, c.unexpected_watches, "unexpected result"),
        ]

        for penalty_score, watches, description in penalty_descriptions:
            # We only penalize the encountered issue for each missing value per
            # command but we still want to record each one, so set the penalty
            # to 0 after the threshold is passed.
            # The counter is only tested for being exactly zero: when there
            # are no missing values it starts at zero and goes negative, so
            # every watch is charged (subject to the remaining budget).
            times_to_penalize = len(c.missing_values)

            for w in watches:
                times_to_penalize -= 1
                penalty_score = min(penalty_available, penalty_score)
                penalty_available -= penalty_score
                penalties[description].append(PenaltyInstance(w, penalty_score))
                if not times_to_penalize:
                    penalty_score = 0

        return name, penalties

    @property
    def penalty(self):
        """Sum of all recorded penalties, capped at ``max_penalty``."""
        result = sum(
            inst.the_penalty
            for pen_cmd in self.penalties.values()
            for inst_list in pen_cmd.pen_dict.values()
            for inst in inst_list
        )
        return min(result, self.max_penalty)

    @property
    def max_penalty(self):
        """Sum of the ``max_penalty`` of every entry in ``penalties``."""
        return sum(p_cat.max_penalty for p_cat in self.penalties.values())

    @property
    def score(self):
        """``1.0 - penalty / max_penalty``, or NaN when max_penalty is 0."""
        try:
            return 1.0 - (self.penalty / float(self.max_penalty))
        except ZeroDivisionError:
            return float("nan")

    @property
    def summary_string(self):
        """The score to four decimals, wrapped in a colour tag.

        Red below 0.25 or when the score is NaN, yellow below 0.75, green
        otherwise.
        """
        score = self.score
        isnan = score != score  # pylint: disable=comparison-with-itself
        color = "g"
        if score < 0.25 or isnan:
            color = "r"
        elif score < 0.75:
            color = "y"

        return "<{}>({:.4f})</>".format(color, score)

    @property
    def verbose_output(self):  # noqa
        """Multi-line report of address resolutions and every penalty.

        Entries and their categories are listed in sorted order. Within a
        category, StepValueInfo results are grouped by expected value into
        step ranges; everything else is printed as-is, with the penalty
        appended when it is non-zero.
        """
        out = []

        # Add address resolutions if present.
        if self.address_resolutions:
            resolved = self.resolved_addresses
            if resolved:
                out.append("\nResolved Addresses:\n")
                out.extend(f"  '{addr}': {res}\n" for addr, res in resolved.items())
            unresolved = self.unresolved_addresses
            if unresolved:
                out.append("\n")
                out.append(f"Unresolved Addresses:\n  {unresolved}\n")

        out.append("\n")
        for command in sorted(self.penalties):
            pen_cmd = self.penalties[command]
            total_penalty = 0
            lines = []
            for category in sorted(pen_cmd.pen_dict):
                instances = pen_cmd.pen_dict[category]
                lines.append("    <r>{}</>:\n".format(category))

                # First pass: fold StepValueInfo results into step ranges,
                # one PenaltyLineRanges per expected value.
                step_value_results = {}
                for result, penalty in instances:
                    if not isinstance(result, StepValueInfo):
                        continue
                    if result.expected_value not in step_value_results:
                        step_value_results[result.expected_value] = PenaltyLineRanges(
                            result.step_index, penalty
                        )
                    else:
                        step_value_results[result.expected_value].add_step(
                            result.step_index, penalty
                        )

                for value, penalty_line_range in step_value_results.items():
                    total_penalty += penalty_line_range.penalty
                    lines.append("      ({}): {}\n".format(value, penalty_line_range))

                # Second pass: everything that is not a StepValueInfo.
                for result, penalty in instances:
                    if isinstance(result, StepValueInfo):
                        continue
                    text = str(result)
                    if penalty:
                        assert penalty > 0, penalty
                        total_penalty += penalty
                        text += " <r>[-{}]</>".format(penalty)
                    lines.append("      {}\n".format(text))

                lines.append("\n")

            # The header needs the total, so it is emitted before the
            # category lines that were collected above.
            out.append(
                "  <b>{}</> <y>[{}/{}]</>\n".format(
                    command, total_penalty, pen_cmd.max_penalty
                )
            )
            out.extend(lines)
        out.append("\n")
        return "".join(out)

    @property
    def resolved_addresses(self):
        """Address names with a non-None resolution, formatted for display."""
        return {
            addr: format_address(res)
            for addr, res in self.address_resolutions.items()
            if res is not None
        }

    @property
    def unresolved_addresses(self):
        """Names of addresses whose resolution is None."""
        return [addr for addr, res in self.address_resolutions.items() if res is None]

    # The properties below expose the matching command-line options.

    @property
    def penalty_variable_optimized(self):
        return self.context.options.penalty_variable_optimized

    @property
    def penalty_irretrievable(self):
        return self.context.options.penalty_irretrievable

    @property
    def penalty_not_evaluatable(self):
        return self.context.options.penalty_not_evaluatable

    @property
    def penalty_incorrect_values(self):
        return self.context.options.penalty_incorrect_values

    @property
    def penalty_missing_values(self):
        return self.context.options.penalty_missing_values

    @property
    def penalty_misordered_values(self):
        return self.context.options.penalty_misordered_values

    @property
    def penalty_unreachable(self):
        return self.context.options.penalty_unreachable

    @property
    def penalty_missing_step(self):
        return self.context.options.penalty_missing_step

    @property
    def penalty_misordered_steps(self):
        return self.context.options.penalty_misordered_steps

    @property
    def penalty_incorrect_program_state(self):
        return self.context.options.penalty_incorrect_program_state
586