xref: /llvm-project/mlir/utils/mbr/mbr/discovery.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
"""This file contains functions for discovering benchmark functions. It works
in a similar way to Python's unittest library.
"""
4import configparser
5import importlib
6import os
7import pathlib
8import re
9import sys
10import types
11
12
def discover_benchmark_modules(top_level_path):
    """Starting from the `top_level_path`, discover python files which contain
    benchmark functions and import them as modules.

    Files are recognised by a name prefix. The prefix is read from the
    `filename_prefix` key of the `[discovery]` section of the `config.ini`
    file next to this module. It defaults to "benchmark_" when the section or
    the key is missing.

    Args:
        top_level_path: Either the path of a single benchmark file (any path
            matching `<prefix>...py`), or a directory that is searched
            recursively for `<prefix>*.py` files. `str` and path-like objects
            are both accepted.

    Yields:
        The imported module for each benchmark file. When searching a
        directory, files are visited in sorted path order. While a module is
        yielded, its directory is appended to `sys.path`. The entry is removed
        when the consumer advances the generator, closes it, or the import
        fails.
    """
    config = configparser.ConfigParser()
    config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini"))
    if "discovery" in config.sections():
        # Fall back to the default if the section exists but lacks the key.
        filename_prefix = config["discovery"].get(
            "filename_prefix", fallback="benchmark_"
        )
    else:
        filename_prefix = "benchmark_"

    # Accept pathlib.Path (and other path-like objects) as well as str.
    top_level_path = os.fspath(top_level_path)
    # The prefix is a literal file-name prefix (it is also used as a glob
    # below), so escape it. Also escape the extension's "." so that e.g. a
    # directory named "benchmark_happy" is not mistaken for a python file.
    if re.search(rf"{re.escape(filename_prefix)}.*\.py$", top_level_path):
        # A specific python file so just include that.
        benchmark_files = [top_level_path]
    else:
        # A directory so recursively search for all python files; sort for a
        # deterministic discovery order.
        benchmark_files = sorted(
            pathlib.Path(top_level_path).rglob(f"{filename_prefix}*.py")
        )

    for benchmark_filename in benchmark_files:
        benchmark_abs_dir = os.path.abspath(os.path.dirname(benchmark_filename))
        module_file_name = os.path.basename(benchmark_filename)
        # Strip only the trailing extension, not every ".py" in the name.
        module_name = os.path.splitext(module_file_name)[0]
        sys.path.append(benchmark_abs_dir)
        try:
            module = importlib.import_module(module_name)
            yield module
        finally:
            # Runs on normal resumption, on generator close() and on import
            # failure, so sys.path never keeps the temporary entry.
            sys.path.pop()
38
39
def get_benchmark_functions(module, benchmark_function_name=None):
    """Discover benchmark functions in a python module.

    It looks for functions with a specific name prefix. The prefix is read
    from the `function_prefix` key of the `[discovery]` section of the
    `config.ini` file next to this module. It defaults to "benchmark_" when
    the section or the key is missing.

    Args:
        module: The module whose attributes (as listed by `dir`) are scanned.
        benchmark_function_name: If given (and non-empty), only the discovered
            function whose `__name__` equals this value is yielded.

    Yields:
        Plain python functions (`types.FunctionType`) whose attribute name
        starts with the prefix, in the order returned by `dir(module)`.
    """
    config = configparser.ConfigParser()
    config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini"))
    if "discovery" in config.sections():
        # Without a fallback, a [discovery] section lacking this key yields
        # None, and `str.startswith(None)` below raises TypeError.
        function_prefix = config["discovery"].get(
            "function_prefix", fallback="benchmark_"
        )
    else:
        function_prefix = "benchmark_"

    module_functions = []
    for attribute_name in dir(module):
        # Check the cheap name test first; only fetch matching attributes.
        if not attribute_name.startswith(function_prefix):
            continue
        attribute = getattr(module, attribute_name)
        if isinstance(attribute, types.FunctionType):
            module_functions.append(attribute)

    if benchmark_function_name:
        # Only the requested function is yielded (matched on the function's
        # own __name__, not on the attribute name it is bound to).
        module_functions = [
            function
            for function in module_functions
            if function.__name__ == benchmark_function_name
        ]
    yield from module_functions
69