xref: /llvm-project/mlir/python/mlir/_mlir_libs/__init__.py (revision c703b4645c79e889fd6a0f3f64f01f957d981aa4)
1# Licensed under the Apache License v2.0 with LLVM Exceptions.
2# See https://llvm.org/LICENSE.txt for license information.
3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from typing import Any, Sequence
6
7import os
8
# Directory containing this module; used as the base for the lib/include
# directories reported below. (os.path.split(...)[0] is what dirname returns.)
_this_dir = os.path.split(__file__)[0]
10
11
def get_lib_dirs() -> Sequence[str]:
    """Return the directories to search when linking to shared libraries.

    Currently this is only the directory containing this module. On some
    platforms the package may need to be built specially for development
    libraries to be exported there.
    """
    lib_dirs = [_this_dir]
    return lib_dirs
19
20
def get_include_dirs() -> Sequence[str]:
    """Return the include directories for compiling against exported C libraries.

    This is the ``include`` subdirectory next to this module. Depending on how
    the package was built, development C libraries may or may not be present.
    """
    include_dir = os.path.join(_this_dir, "include")
    return [include_dir]
28
29
30# Perform Python level site initialization. This involves:
31#   1. Attempting to load initializer modules, specific to the distribution.
32#   2. Defining the concrete mlir.ir.Context that does site specific
33#      initialization.
34#
35# Aside from just being far more convenient to do this at the Python level,
36# it is actually quite hard/impossible to have such __init__ hooks, given
37# the pybind memory model (i.e. there is not a Python reference to the object
38# in the scope of the base class __init__).
39#
40# For #1, we:
41#   a. Probe for modules named '_mlirRegisterEverything' and
42#     '_site_initialize_{i}', where 'i' is a number starting at zero and
43#     proceeding so long as a module with the name is found.
44#   b. If the module has a 'register_dialects' attribute, it will be called
45#     immediately with a DialectRegistry to populate.
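#   c. If the module has a 'context_init_hook', it will be added to a list
#     of callbacks that are invoked during Context initialization (and
#     passed the Context under construction). They run right after the
#     dialect registry is attached, before multithreading is enabled,
#     dialects are loaded, or LLVM translations are registered.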
49#   d. If the module has a 'disable_multithreading' attribute, it will be
50#     taken as a boolean. If it is True for any initializer, then the
51#     default behavior of enabling multithreading on the context
52#     will be suppressed. This complies with the original behavior of all
53#     contexts being created with multithreading enabled while allowing
54#     this behavior to be changed if needed (i.e. if a context_init_hook
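#     explicitly sets up multithreading).
#   e. If the module has a 'disable_load_all_available_dialects' attribute
#     (its value is not inspected), Contexts will not eagerly load all
#     available dialects. A Context constructed without an explicit
#     load_on_create_dialects argument then loads only the dialects
#     registered via append_load_on_create_dialect.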
56#
57# This facility allows downstreams to customize Context creation to their
58# needs.
59
# Lazily-initialized module state:
#   _dialect_registry: shared ir.DialectRegistry, created by get_dialect_registry().
#   _load_on_create_dialects: list of dialect names, created on first use by
#     append_load_on_create_dialect() / get_load_on_create_dialects().
_dialect_registry = _load_on_create_dialects = None
62
63
def get_dialect_registry():
    """Return the module-level ``ir.DialectRegistry``, creating it on first use.

    The same registry object is returned on every call. Initializer modules
    populate it, and the site-specific ``ir.Context`` appends it to each new
    context.
    """
    global _dialect_registry

    if _dialect_registry is not None:
        return _dialect_registry

    # Deferred import: the native extension is only needed once a registry
    # actually has to be created.
    from ._mlir import ir

    _dialect_registry = ir.DialectRegistry()
    return _dialect_registry
73
74
def append_load_on_create_dialect(dialect: str):
    """Add ``dialect`` to the global list of dialects to load on Context creation.

    The list is created on first use. A Context consults it only when an
    initializer disabled eager loading of all dialects and no explicit
    ``load_on_create_dialects`` argument was passed.
    """
    global _load_on_create_dialects
    pending = _load_on_create_dialects
    if pending is None:
        pending = _load_on_create_dialects = []
    pending.append(dialect)
81
82
def get_load_on_create_dialects():
    """Return the global list of dialects to load on Context creation.

    An empty list is created and stored on first use. The returned object is
    the stored list itself, not a copy.
    """
    global _load_on_create_dialects
    if _load_on_create_dialects is not None:
        return _load_on_create_dialects
    _load_on_create_dialects = []
    return _load_on_create_dialects
88
89
def _site_initialize():
    """Perform Python-level site initialization of the MLIR bindings.

    Probing for initializer modules:
      * First the optional '_mlirRegisterEverything' module is probed.
      * Then '_site_initialize_0', '_site_initialize_1', ... are probed,
        stopping at the first one that cannot be imported.
      * Each module found is applied via the hooks it exposes:
        'register_dialects', 'context_init_hook', 'disable_multithreading'
        and 'disable_load_all_available_dialects'.

    Afterwards two site-specific classes are installed into ``ir``:

    * ``ir.Context``: a subclass of ``ir._BaseContext``. Its constructor:
        - attaches the shared dialect registry,
        - runs the collected init hooks,
        - enables multithreading unless disabled,
        - loads dialects,
        - registers LLVM translations when '_mlirRegisterEverything' was found.
    * ``ir.MLIRError``: an exception carrying diagnostic information.
    """
    import importlib
    import itertools
    import logging
    from ._mlir import ir

    logger = logging.getLogger(__name__)
    post_init_hooks = []
    disable_multithreading = False
    # This flag disables eagerly loading all dialects. Eagerly loading is often
    # not the desired behavior (see
    # https://github.com/llvm/llvm-project/issues/56037), and the logic is that
    # if any module has this attribute set, then we don't load all (e.g., it's
    # being used in a solution where the loading is controlled).
    disable_load_all_available_dialects = False

    def process_initializer_module(module_name):
        """Import initializer ``module_name`` and apply the hooks it exposes.

        Returns the imported module, or None if it could not be imported.
        """
        nonlocal disable_multithreading
        nonlocal disable_load_all_available_dialects
        try:
            m = importlib.import_module(f".{module_name}", __name__)
        except ImportError as e:
            probed = f"{__name__}.{module_name}"
            if isinstance(e, ModuleNotFoundError) and e.name in (probed, None):
                # The probed module itself does not exist. This is the normal,
                # silent way probing ends. A ModuleNotFoundError without a
                # name is treated the same way, since the culprit is unknown.
                return None
            # Anything else is a real failure and is reported. This includes
            # a ModuleNotFoundError for one of the initializer's own
            # dependencies, which previously was silently ignored.
            message = (
                f"Error importing mlir initializer {module_name}. This may "
                "happen in unclean incremental builds but is likely a real bug if "
                "encountered otherwise and the MLIR Python API may not function."
            )
            logger.warning(message, exc_info=True)
            return None

        logger.debug("Initializing MLIR with module: %s", module_name)
        if hasattr(m, "register_dialects"):
            logger.debug("Registering dialects from initializer %r", m)
            m.register_dialects(get_dialect_registry())
        if hasattr(m, "context_init_hook"):
            logger.debug("Adding context init hook from %r", m)
            post_init_hooks.append(m.context_init_hook)
        if hasattr(m, "disable_multithreading"):
            if bool(m.disable_multithreading):
                logger.debug("Disabling multi-threading for context")
                disable_multithreading = True
        if hasattr(m, "disable_load_all_available_dialects"):
            # Presence alone disables eager loading; the value is not inspected.
            disable_load_all_available_dialects = True
        return m

    # If _mlirRegisterEverything is built, include it as an initializer module.
    # Keep the module object so each Context can register its LLVM
    # translations.
    init_module = process_initializer_module("_mlirRegisterEverything")

    # Load all _site_initialize_{i} modules, where 'i' is a number starting
    # at 0, stopping at the first one that cannot be imported.
    for i in itertools.count():
        if process_initializer_module(f"_site_initialize_{i}") is None:
            break

    class Context(ir._BaseContext):
        """``ir.Context`` with site-specific initialization applied.

        Args:
          load_on_create_dialects: optional iterable of dialect names to load
            on construction. When None:
              - if an initializer disabled eager loading, the global list from
                get_load_on_create_dialects() is used;
              - otherwise all available dialects are loaded.
        """

        def __init__(self, load_on_create_dialects=None, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.append_dialect_registry(get_dialect_registry())
            # Hooks run before multithreading is enabled, so a hook can set up
            # threading itself (together with 'disable_multithreading').
            for hook in post_init_hooks:
                hook(self)
            if not disable_multithreading:
                self.enable_multithreading(True)
            if load_on_create_dialects is not None:
                logger.debug(
                    "Loading all dialects from load_on_create_dialects arg %r",
                    load_on_create_dialects,
                )
                for dialect in load_on_create_dialects:
                    # This triggers loading the dialect into the context.
                    _ = self.dialects[dialect]
            else:
                if disable_load_all_available_dialects:
                    dialects = get_load_on_create_dialects()
                    if dialects:
                        logger.debug(
                            "Loading all dialects from global load_on_create_dialects %r",
                            dialects,
                        )
                        for dialect in dialects:
                            # This triggers loading the dialect into the context.
                            _ = self.dialects[dialect]
                else:
                    logger.debug("Loading all available dialects")
                    self.load_all_available_dialects()
            if init_module is not None:
                logger.debug(
                    "Registering translations from initializer %r", init_module
                )
                init_module.register_llvm_translations(self)

    ir.Context = Context

    class MLIRError(Exception):
        """
        An exception with diagnostic information. Has the following fields:
          message: str
          error_diagnostics: List[ir.DiagnosticInfo]
        """

        def __init__(self, message, error_diagnostics):
            self.message = message
            self.error_diagnostics = error_diagnostics
            super().__init__(message, error_diagnostics)

        def __str__(self):
            s = self.message
            if self.error_diagnostics:
                s += ":"
            for diag in self.error_diagnostics:
                # str(location) is stripped of its first 4 and last character
                # (presumably the "loc(" ... ")" wrapper).
                s += (
                    "\nerror: "
                    + str(diag.location)[4:-1]
                    + ": "
                    + diag.message.replace("\n", "\n  ")
                )
                for note in diag.notes:
                    s += (
                        "\n note: "
                        + str(note.location)[4:-1]
                        + ": "
                        + note.message.replace("\n", "\n  ")
                    )
            return s

    ir.MLIRError = MLIRError
221
222
# Run site initialization at import time, so that importing this package
# installs the site-specific ir.Context and ir.MLIRError classes.
_site_initialize()
224