xref: /llvm-project/mlir/benchmark/python/common.py (revision a05e20b9720f8b012f06f410d92f1f22b55bce74)
"""Common utilities that are useful for all the benchmarks."""
2fa90c9d5SSaurabh Jhaimport numpy as np
3fa90c9d5SSaurabh Jha
4fa90c9d5SSaurabh Jhafrom mlir import ir
5fa90c9d5SSaurabh Jhafrom mlir.dialects import arith
623aa5a74SRiver Riddlefrom mlir.dialects import func
7fa90c9d5SSaurabh Jhafrom mlir.dialects import memref
8fa90c9d5SSaurabh Jhafrom mlir.dialects import scf
9fa90c9d5SSaurabh Jhafrom mlir.passmanager import PassManager
10fa90c9d5SSaurabh Jha
11fa90c9d5SSaurabh Jha
def setup_passes(mlir_module):
    """Run the benchmark pass pipeline on `mlir_module`.

    The pipeline is the `sparsifier` pass nested under `builtin.module`,
    configured with `parallelization-strategy=none`. It is parsed with
    `PassManager.parse` and run on `mlir_module.operation`. Nothing is
    returned.
    """
    sparsifier_options = "parallelization-strategy=none"
    pass_manager = PassManager.parse(
        f"builtin.module(sparsifier{{{sparsifier_options}}})"
    )
    pass_manager.run(mlir_module.operation)
19fa90c9d5SSaurabh Jha
20fa90c9d5SSaurabh Jha
def create_sparse_np_tensor(dimensions, number_of_elements):
    """Construct a float64 numpy tensor with an exact number of nonzeros.

    The tensor has shape `dimensions` and is zero everywhere except at
    `number_of_elements` distinct, randomly chosen positions. Each of those
    positions holds a value drawn uniformly from the interval [1, 100).

    Positions are sampled without replacement so that two draws can never
    land on the same cell and silently lower the nonzero count.

    Args:
        dimensions: Shape of the tensor, as accepted by `np.zeros`.
        number_of_elements: Number of nonzero entries to place. Values
            outside `[0, tensor.size]` are clamped to that range, since a
            tensor cannot hold more nonzero entries than it has cells.

    Returns:
        A `np.float64` array of shape `dimensions`.
    """
    tensor = np.zeros(dimensions, np.float64)
    nonzero_count = max(0, min(number_of_elements, tensor.size))
    if nonzero_count == 0:
        return tensor
    # Sample flat (row-major) positions without replacement: every chosen
    # cell is distinct, so exactly `nonzero_count` entries become nonzero.
    flat_positions = np.random.choice(tensor.size, nonzero_count, replace=False)
    values = np.random.uniform(1, 100, nonzero_count)
    # np.put writes through the flattened view of `tensor`, in place.
    np.put(tensor, flat_positions, values)
    return tensor
37fa90c9d5SSaurabh Jha
38fa90c9d5SSaurabh Jha
def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
    """Return the single operation contained in an mlir module.

    The module's operation must have exactly one region, that region exactly
    one block, and that block exactly one operation. Each of these is checked
    with an assertion before the operation is returned.
    """
    regions = module.operation.regions
    assert len(regions) == 1, "Expected kernel module to have only one region"
    blocks = regions[0].blocks
    assert len(blocks) == 1, "Expected kernel module to have only one block"
    operations = blocks[0].operations
    assert len(operations) == 1, "Expected kernel module to have only one operation"
    return operations[0]
54fa90c9d5SSaurabh Jha
55fa90c9d5SSaurabh Jha
def emit_timer_func() -> func.FuncOp:
    """Create and return the private declaration of the `nanoTime` function.

    The declared function takes no arguments and returns one signless 64-bit
    integer. It is tagged with the `llvm.emit_c_interface` unit attribute.
    If the nanoTime function is used, the `MLIR_RUNNER_UTILS` and
    `MLIR_C_RUNNER_UTILS` must be included.
    """
    signature = ([], [ir.IntegerType.get_signless(64)])
    timer_decl = func.FuncOp("nanoTime", signature, visibility="private")
    timer_decl.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    return timer_decl
64fa90c9d5SSaurabh Jha
65fa90c9d5SSaurabh Jha
def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
    """Build a public `main` function that times repeated kernel calls.

    `main` takes the arguments of `kernel_func` followed by one extra
    argument, a 1-D dynamically sized memref of i64 used as the timing
    buffer, and returns the same result types as `kernel_func`.

    Its body is an `scf.for` loop from 0 to the size of the timing buffer's
    first dimension with step 1, so the number of iterations is decided by
    the buffer the caller passes in (one iteration per buffer slot), not by
    the number of kernel results. Each iteration calls `timer_func`, calls
    `kernel_func`, calls `timer_func` again, and stores the difference of the
    two timer values into the buffer at the current induction variable.

    The last `len(kernel_func.type.results)` kernel arguments are used as the
    initial loop-carried values; the kernel's results are yielded to the next
    iteration in their place, and the loop's final results are returned from
    `main`.

    Args:
        kernel_func: FuncOp of the function to benchmark.
        timer_func: FuncOp of the timer function called before and after each
            kernel call (for example the declaration built by
            `emit_timer_func`).

    Returns:
        The newly created `main` FuncOp.
    """
    i64_type = ir.IntegerType.get_signless(64)
    # 1-D memref of i64 with a dynamic extent: one slot per timed iteration.
    memref_of_i64_type = ir.MemRefType.get(
        [ir.ShapedType.get_dynamic_size()], i64_type
    )
    wrapped_func = func.FuncOp(
        # Same signature as the kernel plus a trailing buffer for the timings.
        "main",
        (kernel_func.arguments.types + [memref_of_i64_type], kernel_func.type.results),
        visibility="public",
    )
    wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()

    num_results = len(kernel_func.type.results)
    with ir.InsertionPoint(wrapped_func.add_entry_block()):
        timer_buffer = wrapped_func.arguments[-1]
        zero = arith.ConstantOp.create_index(0)
        # The loop's upper bound is the length of the timing buffer.
        n_iterations = memref.DimOp(timer_buffer, zero)
        one = arith.ConstantOp.create_index(1)
        # The arguments just before the timing buffer seed the loop-carried
        # values; there are as many of them as the kernel has results.
        iter_args = list(wrapped_func.arguments[-num_results - 1 : -1])
        loop = scf.ForOp(zero, n_iterations, one, iter_args)
        with ir.InsertionPoint(loop.body):
            start = func.CallOp(timer_func, [])
            call = func.CallOp(
                kernel_func,
                # Leading arguments are passed through unchanged; the trailing
                # ones are replaced by the current loop-carried values.
                wrapped_func.arguments[: -num_results - 1] + loop.inner_iter_args,
            )
            end = func.CallOp(timer_func, [])
            time_taken = arith.SubIOp(end, start)
            memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable])
            scf.YieldOp(list(call.results))
        func.ReturnOp(loop)

    return wrapped_func
105