xref: /llvm-project/mlir/utils/jupyter/mlir_opt_kernel/kernel.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from subprocess import Popen
6import os
7import subprocess
8import tempfile
9import traceback
10from ipykernel.kernelbase import Kernel
11
# Version string shared by the kernel's implementation/language version and
# its banner.
__version__ = "0.0.1"
13
14
15def _get_executable():
16    """Find the mlir-opt executable."""
17
18    def is_exe(fpath):
19        """Returns whether executable file."""
20        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
21
22    program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
23    path, name = os.path.split(program)
24    # Attempt to get the executable
25    if path:
26        if is_exe(program):
27            return program
28    else:
29        for path in os.environ["PATH"].split(os.pathsep):
30            file = os.path.join(path, name)
31            if is_exe(file):
32                return file
33    raise OSError("mlir-opt not found, please see README")
34
35
class MlirOptKernel(Kernel):
    """Kernel using mlir-opt inside jupyter.

    The reproducer syntax (`// configuration:`) is used to run passes. The
    previous result can be referenced by ending a cell with a `_` line (this
    variable is reset upon error). E.g.,

    ```mlir
    // configuration: --pass
    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
    ```

    ```mlir
    // configuration: --next-pass
    _
    ```
    """

    implementation = "mlir"
    implementation_version = __version__

    language_version = __version__
    language = "mlir"
    language_info = {
        "name": "mlir",
        "codemirror_mode": {"name": "mlir"},
        "mimetype": "text/x-mlir",
        "file_extension": ".mlir",
        "pygments_lexer": "text",
    }

    @property
    def banner(self):
        """Returns kernel banner."""
        # Just a placeholder.
        return "mlir-opt kernel %s" % __version__

    def __init__(self, **kwargs):
        Kernel.__init__(self, **kwargs)
        # Decoded output of the last successful run. It is None initially and
        # after any failed run, and is substituted for a final `_` line.
        self._ = None
        # Path to mlir-opt, resolved lazily by get_executable().
        self.executable = None
        # Whether the request currently being executed suppresses output.
        self.silent = False

    def get_executable(self):
        """Returns the mlir-opt executable path, resolving it on first use.

        Raises:
          OSError: if mlir-opt cannot be found.
        """
        if not self.executable:
            self.executable = _get_executable()
        return self.executable

    def process_output(self, output):
        """Reports regular command output on the stdout stream unless silent."""
        if not self.silent:
            stream_content = {"name": "stdout", "text": output}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def process_error(self, output):
        """Reports error output on the stderr stream unless silent."""
        if not self.silent:
            stream_content = {"name": "stderr", "text": output}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def _expand_previous_result(self, code):
        """Replaces a final `_` line of `code` with the previous result.

        Trailing whitespace after the `_` is ignored. A cell consisting only
        of `_` is accepted too. Code without such a line is returned unchanged.

        Raises:
          NameError: if `_` is used but there is no previous result.
        """
        stripped = code.rstrip()
        if stripped != "_" and not stripped.endswith("\n_"):
            return code
        if self._ is None:
            raise NameError("No previous result set")
        return stripped[:-1] + self._

    def _run(self, code):
        """Runs mlir-opt on `code` by piping it via the filesystem.

        Returns:
          A (stdout, stderr, exit code) tuple. Both streams are bytes, with
          the temporary input file's path replaced by `<<input>>`.

        Raises:
          OSError: if mlir-opt cannot be found.
          NameError: if `_` is used without a previous result.
        """
        # Do the checks that can fail before creating the temporary file.
        executable = self.get_executable()
        code = self._expand_previous_result(code)

        inputmlir = tempfile.NamedTemporaryFile(delete=False)
        try:
            # Leaving the `with` closes the file (also if the write fails),
            # so mlir-opt sees complete contents. The finally always removes it.
            with inputmlir:
                inputmlir.write(code.encode("utf-8"))
            command = [
                # Specify input and output file to error out if also
                # set as arg.
                executable,
                "--color",
                inputmlir.name,
                "-o",
                "-",
            ]
            # Unlike a bare Popen().communicate(), subprocess.run (CPython
            # 3.7+) kills the child if waiting is interrupted, e.g. by
            # KeyboardInterrupt, so mlir-opt is not left running.
            result = subprocess.run(
                command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
        finally:
            os.unlink(inputmlir.name)

        # Replace temporary filename with placeholder. This takes the very
        # remote chance where the full input filename (generated above)
        # overlaps with something in the dump unrelated to the file.
        fname = os.fsencode(inputmlir.name)
        output = result.stdout.replace(fname, b"<<input>>")
        errors = result.stderr.replace(fname, b"<<input>>")
        return output, errors, result.returncode

    def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        """Execute user code using mlir-opt binary.

        Returns:
          The execute_reply content. Its status is:
          - "ok" on success;
          - "error" if mlir-opt exited non-zero or the cell could not be run,
            with the exit code (255 for internal failures) as `evalue`;
          - "abort" if execution was interrupted.
        """

        def ok_status():
            """Returns OK status."""
            return {
                "status": "ok",
                "execution_count": self.execution_count,
                "payload": [],
                "user_expressions": {},
            }

        self.silent = silent
        if not code.strip():
            return ok_status()

        try:
            output, errors, exitcode = self._run(code)
            # `_` only ever holds the output of a successful run.
            self._ = None if exitcode else output.decode("utf-8")
        except KeyboardInterrupt:
            return {"status": "abort", "execution_count": self.execution_count}
        except Exception as error:
            # Print traceback for local debugging.
            traceback.print_exc()
            self._ = None
            exitcode = 255
            errors = repr(error).encode("utf-8")

        # Diagnostics are only displayed, so undecodable bytes must not make
        # the request itself fail.
        error_text = errors.decode("utf-8", "replace")

        if exitcode:
            content = {"ename": "", "evalue": str(exitcode), "traceback": []}

            self.send_response(self.iopub_socket, "error", content)
            self.process_error(error_text)

            content["execution_count"] = self.execution_count
            content["status"] = "error"
            return content

        if not silent:
            content = {
                "execution_count": self.execution_count,
                "data": {"text/x-mlir": self._},
                "metadata": {},
            }
            self.send_response(self.iopub_socket, "execute_result", content)
            self.process_output(self._)
            self.process_error(error_text)
        return ok_status()
189