1"""This file contains the main function that's called by the CLI of the library.
2"""

import os
import sys
import time

import numpy as np

from discovery import discover_benchmark_modules, get_benchmark_functions
from stats import has_enough_measurements


def main(top_level_path, stop_on_error):
    """Top-level function called when the CLI is invoked."""
16    if "::" in top_level_path:
17        if top_level_path.count("::") > 1:
18            raise AssertionError(f"Invalid path {top_level_path}")
19        top_level_path, benchmark_function_name = top_level_path.split("::")
20    else:
21        benchmark_function_name = None
22
23    if not os.path.exists(top_level_path):
24        raise AssertionError(f"The top-level path {top_level_path} doesn't exist")
25
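    # Walk the path for benchmark modules, then collect the benchmark
    # functions (optionally filtered by name) defined in each of them.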
    modules = list(discover_benchmark_modules(top_level_path))
    benchmark_dicts = []
    for module in modules:
        benchmark_functions = list(
            get_benchmark_functions(module, benchmark_function_name)
        )
        for benchmark_function in benchmark_functions:
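            # A benchmark function must return a (compiler, runner) two-tuple;
            # the compiler may be falsy when there is no compilation step.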
            try:
                compiler, runner = benchmark_function()
            except (TypeError, ValueError) as e:
                error_message = (
                    f"Obtaining compiler and runner failed because of {e}."
                    f" Benchmark function '{benchmark_function.__name__}'"
                    f" must return a two-tuple value (compiler, runner)."
                )
                if not stop_on_error:
                    print(error_message, file=sys.stderr)
                    continue
                else:
                    raise AssertionError(error_message) from e
            measurements_ns = np.array([])
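            # If a compiler was provided, compile once up front and record the
            # wall-clock compilation time separately from the run times.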
            if compiler:
                start_compile_time_s = time.time()
                try:
                    compiled_callable = compiler()
                except Exception as e:
                    error_message = (
                        f"Compilation of {benchmark_function.__name__} failed"
                        f" because of {e}"
                    )
                    if not stop_on_error:
                        print(error_message, file=sys.stderr)
                        continue
                    else:
                        raise AssertionError(error_message) from e
                total_compile_time_s = time.time() - start_compile_time_s
                runner_args = (compiled_callable,)
            else:
                total_compile_time_s = 0
                runner_args = ()
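            # Keep collecting single-run measurements (in nanoseconds) until
            # the stats module judges the sample large enough.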
            while not has_enough_measurements(measurements_ns):
                try:
                    measurement_ns = runner(*runner_args)
                except Exception as e:
                    error_message = (
                        f"Runner of {benchmark_function.__name__} failed"
                        f" because of {e}"
                    )
                    if not stop_on_error:
                        print(error_message, file=sys.stderr)
                        # Recover from runner error by breaking out of this loop
                        # and continuing forward.
                        break
                    else:
                        raise AssertionError(error_message) from e
                if not isinstance(measurement_ns, int):
                    error_message = (
                        f"Expected benchmark runner function"
                        f" to return an int, got {measurement_ns}"
                    )
                    if not stop_on_error:
                        print(error_message, file=sys.stderr)
                        # Breaking (rather than retrying) avoids an infinite
                        # loop when the runner consistently returns a non-int.
                        break
                    else:
                        raise AssertionError(error_message)
                measurements_ns = np.append(measurements_ns, measurement_ns)

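            # Record a result only if at least one measurement succeeded,
            # reporting times in seconds.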
            if len(measurements_ns) > 0:
                # .tolist() yields plain Python floats rather than numpy
                # scalars, so the result is serializable as-is.
                measurements_s = (measurements_ns * 1e-9).tolist()
                benchmark_identifier = ":".join(
                    [module.__name__, benchmark_function.__name__]
                )
                benchmark_dicts.append(
                    {
                        "name": benchmark_identifier,
                        "compile_time": total_compile_time_s,
                        "execution_time": measurements_s,
                    }
                )

    return benchmark_dicts
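

if __name__ == "__main__":
    # Minimal direct-invocation sketch for ad-hoc debugging; the packaged CLI
    # is the supported entry point. The path (and optional "::function")
    # is taken from the command line, e.g.:
    #     python main.py benchmarks/matmul.py::benchmark_matmul
    import json

    results = main(sys.argv[1], stop_on_error=False)
    print(json.dumps(results, indent=2))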