xref: /llvm-project/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py (revision d95e6d027486876559f1a2a96c33b8ad93cc0ae4)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
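#  This file contains the NvgpuCompiler class, which lowers MLIR modules with the GPU-to-NVVM pipeline and wraps them in a JIT execution engine.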
6
7from mlir import execution_engine
8from mlir import ir
9from mlir import passmanager
10from typing import Sequence
11import errno
12import os
13import sys
14
# Directory that holds this file. It is appended to ``sys.path`` so that
# modules sitting next to this file can be imported by bare name.
_SCRIPT_PATH = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(_SCRIPT_PATH)
17
18
class NvgpuCompiler:
    """Lowers MLIR modules with the GPU-to-NVVM pipeline and JITs them.

    The work is split into two steps that can also be combined:
      * ``compile`` runs ``gpu-lower-to-nvvm-pipeline`` (configured by the
        ``options`` given at construction) over a module, rewriting it in place.
      * ``jit`` wraps a module in an ``execution_engine.ExecutionEngine``.
    """

    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
        """Records the pipeline configuration.

        Args:
            options: Pass-option string inserted verbatim between the braces
                of ``gpu-lower-to-nvvm-pipeline{...}``.
            opt_level: Optimization level forwarded to ``ExecutionEngine``.
            shared_libs: Shared-library paths forwarded to ``ExecutionEngine``.
        """
        # The tripled braces are f-string escapes around the interpolation:
        # they emit a literal "{", the options text, then a literal "}".
        self.pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
        self.shared_libs = shared_libs
        self.opt_level = opt_level

    def __call__(self, module: ir.Module):
        """Convenience application method; equivalent to ``compile(module)``."""
        self.compile(module)

    def compile(self, module: ir.Module):
        """Runs the NVVM lowering pipeline over ``module`` in place.

        The pipeline is parsed in ``module``'s own context. This means it works
        when no ``ir.Context`` is active on the calling thread. It also means
        the pass manager never ends up in a different context than the module
        it runs on.
        """
        pm = passmanager.PassManager.parse(self.pipeline, context=module.context)
        pm.run(module.operation)

    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Wraps ``module`` in a JIT execution engine.

        The module is passed as-is; call ``compile`` first (or use
        ``compile_and_jit``) if it has not been lowered yet.
        """
        return execution_engine.ExecutionEngine(
            module, opt_level=self.opt_level, shared_libs=self.shared_libs
        )

    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Lowers ``module`` in place, then returns a JIT engine for it."""
        self.compile(module)
        return self.jit(module)
46