"""Common utilities that are useful for all the benchmarks."""
import numpy as np

from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.passmanager import PassManager


def setup_passes(mlir_module):
    """Runs the sparsifier pass pipeline on `mlir_module`.

    The pipeline is `builtin.module(sparsifier{parallelization-strategy=none})`,
    i.e. the sparsifier with parallelization disabled. The pass manager is run
    on `mlir_module.operation`, so the module is transformed in place.
    """
    opt = "parallelization-strategy=none"
    # Doubled braces in the f-string emit literal `{`/`}` around the options.
    pipeline = f"builtin.module(sparsifier{{{opt}}})"
    PassManager.parse(pipeline).run(mlir_module.operation)


def create_sparse_np_tensor(dimensions, number_of_elements):
    """Constructs a float64 numpy tensor of shape `dimensions` with exactly
    `number_of_elements` nonzero entries at distinct random positions.

    Nonzero values are drawn uniformly from [1, 100). Positions are sampled
    without replacement, so no two entries land on the same cell. If
    `number_of_elements` exceeds the number of cells, every cell is filled; a
    non-positive count yields an all-zero tensor. Randomness comes from the
    global `np.random` state.
    """
    tensor = np.zeros(dimensions, np.float64)
    count = max(0, min(number_of_elements, tensor.size))
    if count == 0:
        return tensor
    # Sampling flat positions without replacement guarantees `count` distinct
    # cells. Drawing each coordinate independently (with replacement) could hit
    # the same cell twice and leave fewer nonzeros than requested.
    positions = np.random.choice(tensor.size, size=count, replace=False)
    tensor.flat[positions] = np.random.uniform(1, 100, size=count)
    return tensor


def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
    """Takes an mlir module object and extracts the function object out of it.

    This function only works for a module with one region, one block, and one
    operation; that single operation is returned.

    Raises:
      AssertionError: if the module does not have exactly one region, one
        block, or one operation. These checks are `assert` statements, so they
        are skipped when Python runs with `-O`.
    """
    regions = module.operation.regions
    assert len(regions) == 1, "Expected kernel module to have only one region"
    blocks = regions[0].blocks
    assert len(blocks) == 1, "Expected kernel module to have only one block"
    operations = blocks[0].operations
    assert len(operations) == 1, "Expected kernel module to have only one operation"
    return operations[0]


def emit_timer_func() -> func.FuncOp:
    """Returns a private declaration of the `nanoTime() -> i64` function.

    The declaration carries the `llvm.emit_c_interface` attribute. If the
    nanoTime function is used, the `MLIR_RUNNER_UTILS` and
    `MLIR_C_RUNNER_UTILS` must be included.
    """
    i64_type = ir.IntegerType.get_signless(64)
    nano_time = func.FuncOp("nanoTime", ([], [i64_type]), visibility="private")
    nano_time.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    return nano_time


def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
    """Wraps `kernel_func` in a new public `main` function that times it.

    `main` takes the kernel's arguments followed by an extra `memref<?xi64>`
    timing buffer, and has the kernel's result types. Its body is an `scf.for`
    loop whose trip count is the size of the timing buffer (dimension 0). Each
    iteration calls `timer_func` before and after calling `kernel_func`. It
    then stores the difference at the iteration's index in the timing buffer.

    The last `len(kernel_func.type.results)` kernel arguments are threaded
    through the loop as iteration arguments. The first call receives the
    values passed to `main`; each later call receives the previous call's
    results. `main` returns the loop's final values. This function can be used
    to create a "time measuring" variant of a function.

    Args:
      kernel_func: the FuncOp to benchmark.
      timer_func: a FuncOp called with no arguments whose result is stored in
        the i64 timing buffer, e.g. the declaration from `emit_timer_func()`.

    Returns:
      The wrapping `main` FuncOp, with `llvm.emit_c_interface` set.
    """
    i64_type = ir.IntegerType.get_signless(64)
    memref_of_i64_type = ir.MemRefType.get([ir.ShapedType.get_dynamic_size()], i64_type)
    wrapped_func = func.FuncOp(
        # Same signature plus an extra i64 buffer to save timings.
        "main",
        (kernel_func.arguments.types + [memref_of_i64_type], kernel_func.type.results),
        visibility="public",
    )
    wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()

    num_results = len(kernel_func.type.results)
    with ir.InsertionPoint(wrapped_func.add_entry_block()):
        timer_buffer = wrapped_func.arguments[-1]
        zero = arith.ConstantOp.create_index(0)
        # The number of timed iterations is the length of the timing buffer.
        n_iterations = memref.DimOp(timer_buffer, zero)
        one = arith.ConstantOp.create_index(1)
        # The trailing `num_results` kernel arguments (just before the timing
        # buffer) become loop-carried values.
        # NOTE(review): assumes those arguments have the same types as the
        # kernel's results (e.g. output tensors) — confirm for new kernels.
        iter_args = list(wrapped_func.arguments[-num_results - 1 : -1])
        loop = scf.ForOp(zero, n_iterations, one, iter_args)
        with ir.InsertionPoint(loop.body):
            start = func.CallOp(timer_func, [])
            call = func.CallOp(
                kernel_func,
                wrapped_func.arguments[: -num_results - 1] + loop.inner_iter_args,
            )
            end = func.CallOp(timer_func, [])
            time_taken = arith.SubIOp(end, start)
            memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable])
            scf.YieldOp(list(call.results))
        func.ReturnOp(loop)

    return wrapped_func