"""This file contains the main function that's called by the CLI of the library.

It discovers benchmark modules under a path, runs every benchmark function
found in them, and collects compile and execution timings.
"""

import os
import sys
import time

import numpy as np

from discovery import discover_benchmark_modules, get_benchmark_functions
from stats import has_enough_measurements


def _report_or_raise(error_message, stop_on_error, cause=None):
    """Raise ``error_message`` if ``stop_on_error`` is truthy, else print it to stderr.

    Raises:
        AssertionError: when ``stop_on_error`` is truthy, chained from ``cause``.
    """
    if stop_on_error:
        raise AssertionError(error_message) from cause
    print(error_message, file=sys.stderr)


def _split_path(top_level_path):
    """Split ``"path::function"`` into ``(path, function_name)``.

    ``function_name`` is None when the path carries no ``::`` selector.

    Raises:
        AssertionError: if the path contains more than one ``::``.
    """
    if "::" not in top_level_path:
        return top_level_path, None
    if top_level_path.count("::") > 1:
        raise AssertionError(f"Invalid path {top_level_path}")
    path, function_name = top_level_path.split("::")
    return path, function_name


def _run_benchmark(module, benchmark_function, stop_on_error):
    """Run one benchmark function and return its result dict, or None if skipped.

    None is returned when an error was reported instead of raised
    (``stop_on_error`` falsy), or when no measurement was collected.
    """
    function_name = benchmark_function.__name__

    try:
        compiler, runner = benchmark_function()
    except (TypeError, ValueError) as e:
        _report_or_raise(
            f"Obtaining compiler and runner failed because of {e}."
            f" Benchmark function '{function_name}'"
            f" must return a two-tuple value (compiler, runner).",
            stop_on_error,
            e,
        )
        return None

    if compiler:
        # perf_counter is monotonic: unlike time.time(), the measured duration
        # cannot be skewed (or go negative) when the wall clock is adjusted.
        start_compile_time_s = time.perf_counter()
        try:
            compiled_callable = compiler()
        except Exception as e:
            _report_or_raise(
                f"Compilation of {function_name} failed because of {e}",
                stop_on_error,
                e,
            )
            return None
        total_compile_time_s = time.perf_counter() - start_compile_time_s
        # The runner receives the compiled artefact as its only argument.
        runner_args = (compiled_callable,)
    else:
        total_compile_time_s = 0
        runner_args = ()

    measurements_ns = np.array([])
    while not has_enough_measurements(measurements_ns):
        try:
            measurement_ns = runner(*runner_args)
        except Exception as e:
            _report_or_raise(
                f"Runner of {function_name} failed because of {e}",
                stop_on_error,
                e,
            )
            # Recover from runner error by breaking out of this loop and
            # keeping whatever measurements were already collected.
            break
        # NumPy integer scalars (e.g. np.int64) are valid nanosecond counts too.
        if not isinstance(measurement_ns, (int, np.integer)):
            _report_or_raise(
                f"Expected benchmark runner function"
                f" to return an int, got {measurement_ns}",
                stop_on_error,
            )
            # Retrying would leave measurements_ns unchanged, so a runner that
            # keeps returning a bad value would loop forever; stop instead.
            break
        measurements_ns = np.append(measurements_ns, measurement_ns)

    if len(measurements_ns) == 0:
        return None
    return {
        "name": ":".join([module.__name__, function_name]),
        "compile_time": total_compile_time_s,
        # Convert each run from nanoseconds to seconds.
        "execution_time": [t * 1e-9 for t in measurements_ns],
    }


def main(top_level_path, stop_on_error):
    """Top level function called when the CLI is invoked.

    Args:
        top_level_path: Path to search for benchmark modules, optionally
            suffixed with ``::function_name`` to select benchmarks by name.
        stop_on_error: When truthy, any benchmark failure raises
            ``AssertionError``. Otherwise the error is printed to stderr and
            the failing benchmark is skipped (or, for a runner failure, keeps
            the measurements collected so far).

    Returns:
        A list of dicts, one per benchmark that produced measurements, with
        keys ``"name"`` (``"module:function"``), ``"compile_time"`` (seconds,
        0 when there is no compiler) and ``"execution_time"`` (list of per-run
        durations in seconds).

    Raises:
        AssertionError: if the path has more than one ``::`` or does not
            exist, or if a benchmark fails while ``stop_on_error`` is truthy.
    """
    top_level_path, benchmark_function_name = _split_path(top_level_path)

    if not os.path.exists(top_level_path):
        raise AssertionError(f"The top-level path {top_level_path} doesn't exist")

    # Discover every module up front, before any benchmark runs, matching the
    # original eager evaluation order.
    modules = list(discover_benchmark_modules(top_level_path))
    benchmark_dicts = []
    for module in modules:
        benchmark_functions = list(
            get_benchmark_functions(module, benchmark_function_name)
        )
        for benchmark_function in benchmark_functions:
            result = _run_benchmark(module, benchmark_function, stop_on_error)
            if result is not None:
                benchmark_dicts.append(result)

    return benchmark_dicts