# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any, Sequence

import os

# Directory containing this module. get_lib_dirs() reports it as the library
# directory and get_include_dirs() reports its "include" subdirectory.
_this_dir = os.path.split(__file__)[0]


def get_lib_dirs() -> Sequence[str]:
    """Returns the directories to search when linking to shared libraries.

    On some platforms, the package may need to be built specially to export
    development libraries.
    """
    lib_dirs = [_this_dir]
    return lib_dirs


def get_include_dirs() -> Sequence[str]:
    """Returns the directories to search when compiling against exported C libraries.

    Depending on how the package was built, development C libraries may or may
    not be present.
    """
    include_dir = os.path.join(_this_dir, "include")
    return [include_dir]


# Perform Python level site initialization. This involves:
#   1. Attempting to load initializer modules, specific to the distribution.
#   2. Defining the concrete mlir.ir.Context that does site specific
#      initialization.
#
# Aside from just being far more convenient to do this at the Python level,
# it is actually quite hard/impossible to have such __init__ hooks, given
# the pybind memory model (i.e. there is not a Python reference to the object
# in the scope of the base class __init__).
#
# For #1, we:
#   a. Probe for modules named '_mlirRegisterEverything' and
#      '_site_initialize_{i}', where 'i' is a number starting at zero and
#      proceeding so long as a module with the name is found.
#   b. If the module has a 'register_dialects' attribute, it will be called
#      immediately with the shared DialectRegistry to populate.
#   c. If the module has a 'context_init_hook', it will be added to a list
#      of callbacks that are invoked during Context initialization (and
#      passed the Context under construction). They run right after the
#      shared DialectRegistry is attached, before multithreading is enabled
#      and before any dialects are loaded.
#   d. If the module has a 'disable_multithreading' attribute, it will be
#      taken as a boolean.
#      If it is True for any initializer, then the default behavior of
#      enabling multithreading on the context will be suppressed. This
#      preserves the original behavior of all contexts being created with
#      multithreading enabled while allowing this behavior to be changed if
#      needed (i.e. if a context_init_hook explicitly sets up multithreading).
#   e. If the module has a 'disable_load_all_available_dialects' attribute
#      (only its presence matters, its value is not inspected), Contexts
#      created without an explicit load_on_create_dialects argument load only
#      the dialects from get_load_on_create_dialects() instead of every
#      available dialect.
#   f. If '_mlirRegisterEverything' was found, its
#      register_llvm_translations() is called on every new Context.
#
# This facility allows downstreams to customize Context creation to their
# needs.

# Shared DialectRegistry, created lazily by get_dialect_registry().
_dialect_registry = None
# Dialect names to load on Context creation when eager loading of all
# dialects is disabled; created lazily by get_load_on_create_dialects().
_load_on_create_dialects = None


def get_dialect_registry():
    """Returns the shared DialectRegistry, creating it on first use.

    Initializer modules' register_dialects() populate this registry, and every
    Context created through the ir.Context installed by _site_initialize()
    appends it.
    """
    global _dialect_registry

    if _dialect_registry is None:
        from ._mlir import ir

        _dialect_registry = ir.DialectRegistry()

    return _dialect_registry


def append_load_on_create_dialect(dialect: str):
    """Adds ``dialect`` to the global list of dialects loaded on creation.

    The list is only consulted when some initializer module defines
    'disable_load_all_available_dialects' and a Context is created without an
    explicit ``load_on_create_dialects`` argument.
    """
    get_load_on_create_dialects().append(dialect)


def get_load_on_create_dialects():
    """Returns the global list of dialects loaded on Context creation.

    The returned list is the live module-level list (not a copy), so mutating
    it affects subsequently created Contexts.
    """
    global _load_on_create_dialects
    if _load_on_create_dialects is None:
        _load_on_create_dialects = []
    return _load_on_create_dialects


def _site_initialize():
    """Performs the site initialization described in the comment above.

    Processes the initializer modules, then installs a site-specific
    ``ir.Context`` subclass and an ``ir.MLIRError`` exception on the ``ir``
    module.
    """
    import importlib
    import itertools
    import logging
    from ._mlir import ir

    logger = logging.getLogger(__name__)
    post_init_hooks = []
    disable_multithreading = False
    # This flag disables eagerly loading all dialects. Eagerly loading is often
    # not the desired behavior (see
    # https://github.com/llvm/llvm-project/issues/56037), and the logic is that
    # if any module has this attribute set, then we don't load all (e.g., it's
    # being used in a solution where the loading is controlled).
    disable_load_all_available_dialects = False

    def process_initializer_module(module_name):
        """Imports and applies the initializer ``module_name`` if it exists.

        Returns the imported module, or None if the module does not exist or
        failed to import. Import failures other than the initializer itself
        being absent are logged as warnings.
        """
        nonlocal disable_multithreading
        nonlocal disable_load_all_available_dialects
        qualified_name = f"{__name__}.{module_name}"
        try:
            m = importlib.import_module(f".{module_name}", __name__)
        except ModuleNotFoundError as e:
            if e.name == qualified_name:
                # The initializer itself is not present: expected, stay quiet.
                return None
            # The initializer exists but something it imports is missing;
            # don't mistake that for "not built".
            message = (
                f"Error importing mlir initializer {module_name}: a module it "
                "depends on could not be found. The MLIR Python API may not "
                "function."
            )
            logger.warning(message, exc_info=True)
            return None
        except ImportError:
            message = (
                f"Error importing mlir initializer {module_name}. This may "
                "happen in unclean incremental builds but is likely a real bug if "
                "encountered otherwise and the MLIR Python API may not function."
            )
            logger.warning(message, exc_info=True)
            return None

        logger.debug("Initializing MLIR with module: %s", module_name)
        if hasattr(m, "register_dialects"):
            logger.debug("Registering dialects from initializer %r", m)
            m.register_dialects(get_dialect_registry())
        if hasattr(m, "context_init_hook"):
            logger.debug("Adding context init hook from %r", m)
            post_init_hooks.append(m.context_init_hook)
        if hasattr(m, "disable_multithreading"):
            if bool(m.disable_multithreading):
                logger.debug("Disabling multi-threading for context")
                disable_multithreading = True
        if hasattr(m, "disable_load_all_available_dialects"):
            # Presence alone disables eager loading (see the flag's comment).
            disable_load_all_available_dialects = True
        return m

    # If _mlirRegisterEverything is built, then include it as an initializer
    # module; it is also used below to register translations on each Context.
    init_module = process_initializer_module("_mlirRegisterEverything")

    # Load all _site_initialize_{i} modules, where 'i' is a number starting
    # at 0, stopping at the first one that cannot be imported.
    for i in itertools.count():
        if process_initializer_module(f"_site_initialize_{i}") is None:
            break

    class Context(ir._BaseContext):
        def __init__(self, load_on_create_dialects=None, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.append_dialect_registry(get_dialect_registry())
            for hook in post_init_hooks:
                hook(self)
            if not disable_multithreading:
                self.enable_multithreading(True)
            if load_on_create_dialects is not None:
                logger.debug(
                    "Loading all dialects from load_on_create_dialects arg %r",
                    load_on_create_dialects,
                )
                for dialect in load_on_create_dialects:
                    # This triggers loading the dialect into the context.
                    _ = self.dialects[dialect]
            elif disable_load_all_available_dialects:
                dialects = get_load_on_create_dialects()
                if dialects:
                    logger.debug(
                        "Loading all dialects from global load_on_create_dialects %r",
                        dialects,
                    )
                    for dialect in dialects:
                        # This triggers loading the dialect into the context.
                        _ = self.dialects[dialect]
            else:
                logger.debug("Loading all available dialects")
                self.load_all_available_dialects()
            if init_module is not None:
                logger.debug(
                    "Registering translations from initializer %r", init_module
                )
                init_module.register_llvm_translations(self)

    ir.Context = Context

    def _format_diagnostic(prefix, diag):
        """Formats one diagnostic as '<prefix>: <location>: <message>'."""
        # Drops the first 4 and the last character of str(location);
        # presumably this strips a "loc(...)" wrapper -- confirm against the
        # Location string form.
        location = str(diag.location)[4:-1]
        message = diag.message.replace("\n", "\n ")
        return f"{prefix}: {location}: {message}"

    class MLIRError(Exception):
        """
        An exception with diagnostic information. Has the following fields:
          message: str
          error_diagnostics: List[ir.DiagnosticInfo]
        """

        def __init__(self, message, error_diagnostics):
            self.message = message
            self.error_diagnostics = error_diagnostics
            super().__init__(message, error_diagnostics)

        def __str__(self):
            lines = [self.message]
            # Checking truthiness first also tolerates error_diagnostics=None.
            if self.error_diagnostics:
                lines[0] += ":"
                for diag in self.error_diagnostics:
                    lines.append(_format_diagnostic("error", diag))
                    for note in diag.notes:
                        # Notes are indented one space under their error.
                        lines.append(_format_diagnostic(" note", note))
            return "\n".join(lines)

    ir.MLIRError = MLIRError


_site_initialize()