# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains the NvgpuCompiler class.

from mlir import execution_engine
from mlir import ir
from mlir import passmanager
from typing import Sequence
import errno
import os
import sys

_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)


class NvgpuCompiler:
    """Lowers MLIR modules with the gpu-lower-to-nvvm pipeline and JITs them.

    Args:
        options: Option string spliced verbatim between the braces of
            ``gpu-lower-to-nvvm-pipeline{...}``.
        opt_level: Optimization level forwarded to the execution engine.
        shared_libs: Shared-library paths forwarded to the execution engine.
    """

    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
        # The tripled braces in the f-string emit a literal `{` / `}` around
        # the interpolated options string.
        pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
        self.pipeline = pipeline
        self.shared_libs = shared_libs
        self.opt_level = opt_level

    def __call__(self, module: ir.Module):
        """Convenience application method; same as ``compile(module)``."""
        self.compile(module)

    def compile(self, module: ir.Module):
        """Lowers ``module`` in place by running the configured pipeline.

        The pass pipeline is parsed in ``module``'s own context rather than
        whatever context happens to be active at the call site, so this also
        works when no ``ir.Context`` is entered by the caller. Errors from
        parsing or running the pipeline propagate to the caller.
        """
        pm = passmanager.PassManager.parse(self.pipeline, context=module.context)
        pm.run(module.operation)

    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Wraps an already-lowered ``module`` in a JIT execution engine."""
        return execution_engine.ExecutionEngine(
            module, opt_level=self.opt_level, shared_libs=self.shared_libs
        )

    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Lowers ``module`` in place, then returns a JIT engine for it."""
        self.compile(module)
        return self.jit(module)