xref: /llvm-project/mlir/utils/jupyter/mlir_opt_kernel/kernel.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
104c66eddSJacques Pienaar# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
204c66eddSJacques Pienaar# See https://llvm.org/LICENSE.txt for license information.
304c66eddSJacques Pienaar# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
404c66eddSJacques Pienaar
504c66eddSJacques Pienaarfrom subprocess import Popen
604c66eddSJacques Pienaarimport os
704c66eddSJacques Pienaarimport subprocess
804c66eddSJacques Pienaarimport tempfile
904c66eddSJacques Pienaarimport traceback
1004c66eddSJacques Pienaarfrom ipykernel.kernelbase import Kernel
1104c66eddSJacques Pienaar
12*f9008e63STobias Hieta__version__ = "0.0.1"
1304c66eddSJacques Pienaar
1404c66eddSJacques Pienaar
1504c66eddSJacques Pienaardef _get_executable():
1604c66eddSJacques Pienaar    """Find the mlir-opt executable."""
1704c66eddSJacques Pienaar
1804c66eddSJacques Pienaar    def is_exe(fpath):
1904c66eddSJacques Pienaar        """Returns whether executable file."""
2004c66eddSJacques Pienaar        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
2104c66eddSJacques Pienaar
22*f9008e63STobias Hieta    program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
2304c66eddSJacques Pienaar    path, name = os.path.split(program)
2404c66eddSJacques Pienaar    # Attempt to get the executable
2504c66eddSJacques Pienaar    if path:
2604c66eddSJacques Pienaar        if is_exe(program):
2704c66eddSJacques Pienaar            return program
2804c66eddSJacques Pienaar    else:
2904c66eddSJacques Pienaar        for path in os.environ["PATH"].split(os.pathsep):
3004c66eddSJacques Pienaar            file = os.path.join(path, name)
3104c66eddSJacques Pienaar            if is_exe(file):
3204c66eddSJacques Pienaar                return file
33*f9008e63STobias Hieta    raise OSError("mlir-opt not found, please see README")
3404c66eddSJacques Pienaar
3504c66eddSJacques Pienaar
3604c66eddSJacques Pienaarclass MlirOptKernel(Kernel):
3704c66eddSJacques Pienaar    """Kernel using mlir-opt inside jupyter.
3804c66eddSJacques Pienaar
3904c66eddSJacques Pienaar    The reproducer syntax (`// configuration:`) is used to run passes. The
4004c66eddSJacques Pienaar    previous result can be referenced to by using `_` (this variable is reset
4104c66eddSJacques Pienaar    upon error). E.g.,
4204c66eddSJacques Pienaar
4304c66eddSJacques Pienaar    ```mlir
4404c66eddSJacques Pienaar    // configuration: --pass
452310ced8SRiver Riddle    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
4604c66eddSJacques Pienaar    ```
4704c66eddSJacques Pienaar
4804c66eddSJacques Pienaar    ```mlir
4904c66eddSJacques Pienaar    // configuration: --next-pass
5004c66eddSJacques Pienaar    _
5104c66eddSJacques Pienaar    ```
5204c66eddSJacques Pienaar    """
5304c66eddSJacques Pienaar
54*f9008e63STobias Hieta    implementation = "mlir"
5504c66eddSJacques Pienaar    implementation_version = __version__
5604c66eddSJacques Pienaar
5704c66eddSJacques Pienaar    language_version = __version__
5804c66eddSJacques Pienaar    language = "mlir"
5904c66eddSJacques Pienaar    language_info = {
6004c66eddSJacques Pienaar        "name": "mlir",
61*f9008e63STobias Hieta        "codemirror_mode": {"name": "mlir"},
6204c66eddSJacques Pienaar        "mimetype": "text/x-mlir",
6304c66eddSJacques Pienaar        "file_extension": ".mlir",
64*f9008e63STobias Hieta        "pygments_lexer": "text",
6504c66eddSJacques Pienaar    }
6604c66eddSJacques Pienaar
6704c66eddSJacques Pienaar    @property
6804c66eddSJacques Pienaar    def banner(self):
6904c66eddSJacques Pienaar        """Returns kernel banner."""
7004c66eddSJacques Pienaar        # Just a placeholder.
7104c66eddSJacques Pienaar        return "mlir-opt kernel %s" % __version__
7204c66eddSJacques Pienaar
7304c66eddSJacques Pienaar    def __init__(self, **kwargs):
7404c66eddSJacques Pienaar        Kernel.__init__(self, **kwargs)
7504c66eddSJacques Pienaar        self._ = None
7604c66eddSJacques Pienaar        self.executable = None
7704c66eddSJacques Pienaar        self.silent = False
7804c66eddSJacques Pienaar
7904c66eddSJacques Pienaar    def get_executable(self):
8004c66eddSJacques Pienaar        """Returns the mlir-opt executable path."""
8104c66eddSJacques Pienaar        if not self.executable:
8204c66eddSJacques Pienaar            self.executable = _get_executable()
8304c66eddSJacques Pienaar        return self.executable
8404c66eddSJacques Pienaar
8504c66eddSJacques Pienaar    def process_output(self, output):
8604c66eddSJacques Pienaar        """Reports regular command output."""
8704c66eddSJacques Pienaar        if not self.silent:
8804c66eddSJacques Pienaar            # Send standard output
89*f9008e63STobias Hieta            stream_content = {"name": "stdout", "text": output}
90*f9008e63STobias Hieta            self.send_response(self.iopub_socket, "stream", stream_content)
9104c66eddSJacques Pienaar
9204c66eddSJacques Pienaar    def process_error(self, output):
9304c66eddSJacques Pienaar        """Reports error response."""
9404c66eddSJacques Pienaar        if not self.silent:
9504c66eddSJacques Pienaar            # Send standard error
96*f9008e63STobias Hieta            stream_content = {"name": "stderr", "text": output}
97*f9008e63STobias Hieta            self.send_response(self.iopub_socket, "stream", stream_content)
9804c66eddSJacques Pienaar
99*f9008e63STobias Hieta    def do_execute(
100*f9008e63STobias Hieta        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
101*f9008e63STobias Hieta    ):
10204c66eddSJacques Pienaar        """Execute user code using mlir-opt binary."""
10304c66eddSJacques Pienaar
10404c66eddSJacques Pienaar        def ok_status():
10504c66eddSJacques Pienaar            """Returns OK status."""
10604c66eddSJacques Pienaar            return {
107*f9008e63STobias Hieta                "status": "ok",
108*f9008e63STobias Hieta                "execution_count": self.execution_count,
109*f9008e63STobias Hieta                "payload": [],
110*f9008e63STobias Hieta                "user_expressions": {},
11104c66eddSJacques Pienaar            }
11204c66eddSJacques Pienaar
11304c66eddSJacques Pienaar        def run(code):
11404c66eddSJacques Pienaar            """Run the code by pipeing via filesystem."""
11504c66eddSJacques Pienaar            try:
11604c66eddSJacques Pienaar                inputmlir = tempfile.NamedTemporaryFile(delete=False)
11704c66eddSJacques Pienaar                command = [
11804c66eddSJacques Pienaar                    # Specify input and output file to error out if also
11904c66eddSJacques Pienaar                    # set as arg.
12004c66eddSJacques Pienaar                    self.get_executable(),
121*f9008e63STobias Hieta                    "--color",
12204c66eddSJacques Pienaar                    inputmlir.name,
123*f9008e63STobias Hieta                    "-o",
124*f9008e63STobias Hieta                    "-",
12504c66eddSJacques Pienaar                ]
12604c66eddSJacques Pienaar                # Simple handling of repeating last line.
127*f9008e63STobias Hieta                if code.endswith("\n_"):
12804c66eddSJacques Pienaar                    if not self._:
129*f9008e63STobias Hieta                        raise NameError("No previous result set")
13004c66eddSJacques Pienaar                    code = code[:-1] + self._
13104c66eddSJacques Pienaar                inputmlir.write(code.encode("utf-8"))
13204c66eddSJacques Pienaar                inputmlir.close()
133*f9008e63STobias Hieta                pipe = Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
13404c66eddSJacques Pienaar                output, errors = pipe.communicate()
13504c66eddSJacques Pienaar                exitcode = pipe.returncode
13604c66eddSJacques Pienaar            finally:
13704c66eddSJacques Pienaar                os.unlink(inputmlir.name)
13804c66eddSJacques Pienaar
13904c66eddSJacques Pienaar            # Replace temporary filename with placeholder. This takes the very
14004c66eddSJacques Pienaar            # remote chance where the full input filename (generated above)
14104c66eddSJacques Pienaar            # overlaps with something in the dump unrelated to the file.
14204c66eddSJacques Pienaar            fname = inputmlir.name.encode("utf-8")
14304c66eddSJacques Pienaar            output = output.replace(fname, b"<<input>>")
14404c66eddSJacques Pienaar            errors = errors.replace(fname, b"<<input>>")
14504c66eddSJacques Pienaar            return output, errors, exitcode
14604c66eddSJacques Pienaar
14704c66eddSJacques Pienaar        self.silent = silent
14804c66eddSJacques Pienaar        if not code.strip():
14904c66eddSJacques Pienaar            return ok_status()
15004c66eddSJacques Pienaar
15104c66eddSJacques Pienaar        try:
15204c66eddSJacques Pienaar            output, errors, exitcode = run(code)
15304c66eddSJacques Pienaar
15404c66eddSJacques Pienaar            if exitcode:
15504c66eddSJacques Pienaar                self._ = None
15604c66eddSJacques Pienaar            else:
15704c66eddSJacques Pienaar                self._ = output.decode("utf-8")
15804c66eddSJacques Pienaar        except KeyboardInterrupt:
159*f9008e63STobias Hieta            return {"status": "abort", "execution_count": self.execution_count}
16004c66eddSJacques Pienaar        except Exception as error:
16104c66eddSJacques Pienaar            # Print traceback for local debugging.
16204c66eddSJacques Pienaar            traceback.print_exc()
16304c66eddSJacques Pienaar            self._ = None
16404c66eddSJacques Pienaar            exitcode = 255
16504c66eddSJacques Pienaar            errors = repr(error).encode("utf-8")
16604c66eddSJacques Pienaar
16704c66eddSJacques Pienaar        if exitcode:
168*f9008e63STobias Hieta            content = {"ename": "", "evalue": str(exitcode), "traceback": []}
16904c66eddSJacques Pienaar
170*f9008e63STobias Hieta            self.send_response(self.iopub_socket, "error", content)
17104c66eddSJacques Pienaar            self.process_error(errors.decode("utf-8"))
17204c66eddSJacques Pienaar
173*f9008e63STobias Hieta            content["execution_count"] = self.execution_count
174*f9008e63STobias Hieta            content["status"] = "error"
17504c66eddSJacques Pienaar            return content
17604c66eddSJacques Pienaar
17704c66eddSJacques Pienaar        if not silent:
17804c66eddSJacques Pienaar            data = {}
179*f9008e63STobias Hieta            data["text/x-mlir"] = self._
18004c66eddSJacques Pienaar            content = {
181*f9008e63STobias Hieta                "execution_count": self.execution_count,
182*f9008e63STobias Hieta                "data": data,
183*f9008e63STobias Hieta                "metadata": {},
18404c66eddSJacques Pienaar            }
185*f9008e63STobias Hieta            self.send_response(self.iopub_socket, "execute_result", content)
18604c66eddSJacques Pienaar            self.process_output(self._)
18704c66eddSJacques Pienaar            self.process_error(errors.decode("utf-8"))
18804c66eddSJacques Pienaar        return ok_status()
189