# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from subprocess import Popen
import os
import shutil
import subprocess
import tempfile
import traceback
from ipykernel.kernelbase import Kernel

__version__ = "0.0.1"


def _get_executable():
    """Find the mlir-opt executable.

    The program is taken from the ``MLIR_OPT_EXECUTABLE`` environment variable
    (default ``mlir-opt``). If it contains a directory component only that
    path is checked; otherwise the program is searched for on ``PATH``.

    Returns:
      Path of the mlir-opt executable.

    Raises:
      OSError: if no executable could be found.
    """
    program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
    # shutil.which checks an explicit path directly and otherwise searches
    # PATH (falling back to a default search path when PATH is unset, where
    # the previous hand-rolled lookup raised KeyError).
    executable = shutil.which(program)
    if executable is None:
        raise OSError("mlir-opt not found, please see README")
    return executable


class MlirOptKernel(Kernel):
    """Kernel using mlir-opt inside jupyter.

    The reproducer syntax (`// configuration:`) is used to run passes. The
    previous result can be referenced by using `_` as the last line of a cell
    (this variable is reset upon error). E.g.,

    ```mlir
    // configuration: --pass
    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
    ```

    ```mlir
    // configuration: --next-pass
    _
    ```
    """

    implementation = "mlir"
    implementation_version = __version__

    language_version = __version__
    language = "mlir"
    language_info = {
        "name": "mlir",
        "codemirror_mode": {"name": "mlir"},
        "mimetype": "text/x-mlir",
        "file_extension": ".mlir",
        "pygments_lexer": "text",
    }

    @property
    def banner(self):
        """Returns kernel banner."""
        # Just a placeholder.
        return "mlir-opt kernel %s" % __version__

    def __init__(self, **kwargs):
        Kernel.__init__(self, **kwargs)
        # Output of the last successful execution; substituted for `_`.
        self._ = None
        # Path of the mlir-opt binary, resolved lazily by get_executable().
        self.executable = None
        # Whether stream output is suppressed for the current request.
        self.silent = False

    def get_executable(self):
        """Returns the mlir-opt executable path (looked up once, then cached)."""
        if not self.executable:
            self.executable = _get_executable()
        return self.executable

    def process_output(self, output):
        """Reports regular command output on stdout unless silent."""
        if not self.silent:
            # Send standard output
            stream_content = {"name": "stdout", "text": output}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def process_error(self, output):
        """Reports error output on stderr unless silent."""
        if not self.silent:
            # Send standard error
            stream_content = {"name": "stderr", "text": output}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        """Execute user code using mlir-opt binary.

        On success the textual output becomes the new `_` value and is
        reported as an `execute_result`; on failure `_` is reset and an error
        reply (with the exit code as `evalue`) is returned.
        """

        def ok_status():
            """Returns OK status."""
            return {
                "status": "ok",
                "execution_count": self.execution_count,
                "payload": [],
                "user_expressions": {},
            }

        def run(code):
            """Runs `code` through mlir-opt by piping it via the filesystem.

            A trailing line consisting only of `_` is replaced by the previous
            result before the code is written to a temporary input file.

            Returns:
              Tuple of (stdout bytes, stderr bytes, exit code), with the
              temporary file name replaced by `<<input>>`.

            Raises:
              OSError: if mlir-opt cannot be found or started.
              NameError: if `_` is referenced but no previous result is set.
            """
            executable = self.get_executable()
            # Simple handling of repeating last line. Trailing whitespace is
            # ignored, and a cell consisting solely of `_` is accepted too.
            stripped = code.rstrip()
            if stripped == "_" or stripped.endswith("\n_"):
                if self._ is None:
                    raise NameError("No previous result set")
                code = stripped[:-1] + self._
            # Created before the try block so the cleanup below never refers
            # to an unbound name. delete=False because the file is closed
            # before mlir-opt opens it and is removed explicitly afterwards.
            inputmlir = tempfile.NamedTemporaryFile(delete=False)
            try:
                with inputmlir:
                    inputmlir.write(code.encode("utf-8"))
                command = [
                    # Specify input and output file to error out if also
                    # set as arg.
                    executable,
                    "--color",
                    inputmlir.name,
                    "-o",
                    "-",
                ]
                # subprocess.run kills the child if waiting is interrupted
                # (e.g. by KeyboardInterrupt), so no mlir-opt is left running.
                result = subprocess.run(command, capture_output=True)
            finally:
                os.unlink(inputmlir.name)

            # Replace temporary filename with placeholder. This takes the very
            # remote chance where the full input filename (generated above)
            # overlaps with something in the dump unrelated to the file.
            fname = inputmlir.name.encode("utf-8")
            output = result.stdout.replace(fname, b"<<input>>")
            stderr = result.stderr.replace(fname, b"<<input>>")
            return output, stderr, result.returncode

        self.silent = silent
        if not code.strip():
            return ok_status()

        try:
            output, stderr, exitcode = run(code)

            # Only a successful result may be referenced later as `_`.
            if exitcode:
                self._ = None
            else:
                self._ = output.decode("utf-8", errors="replace")
        except KeyboardInterrupt:
            return {"status": "abort", "execution_count": self.execution_count}
        except Exception as error:
            # Print traceback for local debugging.
            traceback.print_exc()
            self._ = None
            exitcode = 255
            stderr = repr(error).encode("utf-8")

        # Decode leniently: this runs outside the try block above, so an
        # undecodable byte in the diagnostics must not crash the request.
        error_text = stderr.decode("utf-8", errors="replace")

        if exitcode:
            content = {"ename": "", "evalue": str(exitcode), "traceback": []}

            self.send_response(self.iopub_socket, "error", content)
            self.process_error(error_text)

            content["execution_count"] = self.execution_count
            content["status"] = "error"
            return content

        if not silent:
            content = {
                "execution_count": self.execution_count,
                "data": {"text/x-mlir": self._},
                "metadata": {},
            }
            self.send_response(self.iopub_socket, "execute_result", content)
        self.process_output(self._)
        self.process_error(error_text)
        return ok_status()