xref: /llvm-project/mlir/test/python/multithreaded_tests.py (revision f136c800b60dbfacdbb645e7e92acba52e2f279f)
1# RUN: %PYTHON %s
2"""
3This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
4Tests can be run using pytest:
5```bash
6python3.13t -mpytest -vvv multithreaded_tests.py
7```
8
IMPORTANT: running these tests does not check correctness; it only executes the tests in a multi-threaded context,
passing if TSAN reports no warnings and failing otherwise.
11
12
13Details on the generated tests and execution:
141) Multi-threaded execution: all generated tests are executed independently by
15a pool of threads, running each test multiple times, see @multi_threaded for details
16
172) Tests generation: we use existing tests: test/python/ir/*.py,
18test/python/dialects/*.py, etc to generate multi-threaded tests.
In detail, we perform the following:
20a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
21b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
22c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
23In order to import the test file as python module, we remove all executing functions, like
24`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.
25
26
27Observed warnings reported by TSAN.
28
29CPython and free-threading known data-races:
301) ctypes related races: https://github.com/python/cpython/issues/127945
312) LLVM related data-races, llvm::raw_ostream is not thread-safe
32- mlir pass manager
33- dialects/transform_interpreter.py
34- ir/diagnostic_handler.py
35- ir/module.py
363) Dialect gpu module-to-binary method is unsafe
37"""
38import concurrent.futures
39import gc
40import importlib.util
41import os
42import sys
43import threading
44import tempfile
45import unittest
46
47from contextlib import contextmanager
48from functools import partial
49from pathlib import Path
50from typing import Optional, List
51
52import mlir.dialects.arith as arith
53from mlir.dialects import transform
54from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
55
56
def import_from_path(module_name: str, file_path: Path):
    """Import the Python file at ``file_path`` as a module named ``module_name``.

    The module object is registered in ``sys.modules`` under ``module_name``
    before its code is executed and is returned once execution has finished.

    Args:
        module_name: key used for the ``sys.modules`` entry and the module name.
        file_path: location of the source file to execute.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: if no import spec (or loader) can be built for
            ``file_path``, e.g. because its suffix is not an importable one.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        # Fail with a clear message instead of an AttributeError on ``None``.
        raise ImportError(
            f"Cannot create an import spec for {module_name!r} from {file_path}"
        )
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered if its code fails.
        sys.modules.pop(module_name, None)
        raise
    return module
63
64
def copy_and_update(src_filepath: Path, dst_filepath: Path):
    """Copy a test file, dropping the lines that would execute tests on import.

    Every line of ``src_filepath`` is written to ``dst_filepath`` unchanged,
    except lines that start (at column 0) with one of the known "run the test
    now" prefixes such as ``run(`` or ``@run``; those lines are omitted so the
    copy can be imported as a plain module.

    Args:
        src_filepath: existing test file to read.
        dst_filepath: file to create/overwrite with the filtered content.
    """
    # Built once instead of once per line; ``str.startswith`` accepts a tuple.
    skip_prefixes = (
        "run(",
        "@run",
        "@constructAndPrintInModule",
        "run_apply_patterns(",
        "@run_apply_patterns",
        "@test_in_context",
        "@construct_and_print_in_module",
    )
    # Explicit encoding keeps the copy independent of the locale's default.
    with open(src_filepath, "r", encoding="utf-8") as reader:
        with open(dst_filepath, "w", encoding="utf-8") as writer:
            writer.writelines(
                src_line
                for src_line in reader
                if not src_line.startswith(skip_prefixes)
            )
84
85
86# Helper run functions
def run(f):
    """Call ``f`` with no arguments and discard whatever it returns."""
    f()
    return None
89
90
def run_with_context_and_location(f):
    """Print a ``TEST:`` header, then call ``f()`` with a ``Context`` and an
    unknown ``Location`` entered as context managers. Returns ``f``."""
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        with Location.unknown():
            f()
    return f
96
97
def run_with_insertion_point(f):
    """Print a ``TEST:`` header, call ``f(ctx)`` with the insertion point set to
    the body of a newly created module, then print that module.

    ``ctx`` is the ``Context`` entered for the duration of the call.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context() as ctx:
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                f(ctx)
            print(mod)
105
106
def run_with_insertion_point_v2(f):
    """Print a ``TEST:`` header, call ``f()`` with the insertion point set to
    the body of a newly created module, then print that module. Returns ``f``."""
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                f()
            print(mod)
    return f
115
116
def run_with_insertion_point_v3(f):
    """Call ``f(module)`` with the insertion point set to the body of a newly
    created module, then print that module. Returns ``f``.

    The ``TEST:`` header is printed after the insertion point is entered.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                test_name = f.__name__
                print("\nTEST:", test_name)
                f(mod)
            print(mod)
    return f
125
126
def run_with_insertion_point_v4(f):
    """Print a ``TEST:`` header and call ``f()`` with the insertion point set to
    the body of a newly created module. Returns ``f``.

    ``allow_unregistered_dialects`` is set to ``True`` on the context before
    the module is created; the module itself is not printed.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context() as ctx:
        with Location.unknown():
            ctx.allow_unregistered_dialects = True
            mod = Module.create()
            with InsertionPoint(mod.body):
                f()
    return f
135
136
def run_apply_patterns(f):
    """Call ``f()`` with the insertion point set to the ``patterns`` of a
    ``transform.ApplyPatternsOp`` nested in a ``transform.SequenceOp``.

    A ``transform.YieldOp`` is appended to the sequence body after ``f``
    returns. The ``TEST:`` header and the enclosing module are printed once
    the module has been built. Returns ``f``.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                failure_mode = transform.FailurePropagationMode.Propagate
                result_types = []
                target_type = transform.AnyOpType.get()
                seq_op = transform.SequenceOp(failure_mode, result_types, target_type)
                with InsertionPoint(seq_op.body):
                    apply_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                    with InsertionPoint(apply_op.patterns):
                        f()
                    transform.YieldOp()
            print("\nTEST:", f.__name__)
            print(mod)
    return f
154
155
def run_transform_tensor_ext(f):
    """Print a ``TEST:`` header, then call ``f(sequence.bodyTarget)`` with the
    insertion point set to the body of a ``transform.SequenceOp``.

    A ``transform.YieldOp`` is appended after ``f`` returns and the enclosing
    module is printed. Returns ``f``.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                failure_mode = transform.FailurePropagationMode.Propagate
                result_types = []
                target_type = transform.AnyOpType.get()
                seq_op = transform.SequenceOp(failure_mode, result_types, target_type)
                with InsertionPoint(seq_op.body):
                    f(seq_op.bodyTarget)
                    transform.YieldOp()
            print(mod)
    return f
171
172
def run_transform_structured_ext(f):
    """Call ``f()`` with the insertion point set to the body of a newly created
    module, then call ``verify()`` on the module's operation and print the
    module. Returns ``f``.

    The ``TEST:`` header is printed after the insertion point is entered.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                test_name = f.__name__
                print("\nTEST:", test_name)
                f()
            mod.operation.verify()
            print(mod)
    return f
182
183
def run_construct_and_print_in_module(f):
    """Print a ``TEST:`` header, call ``f(module)`` with the insertion point set
    to the body of a newly created module, and print whatever ``f`` returned
    unless it returned ``None``. Returns ``f``."""
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        with Location.unknown():
            fresh_module = Module.create()
            with InsertionPoint(fresh_module.body):
                returned = f(fresh_module)
            if returned is None:
                return f
            print(returned)
    return f
193
194
# Source test modules used to generate the multi-threaded tests.
# Each entry is ``(module_path, exec_fn)`` or ``(module_path, exec_fn, patterns)``:
#   - ``module_path`` is the test file path relative to this file, without ``.py``;
#   - ``exec_fn`` is the helper used to execute each collected test function;
#   - ``patterns`` (optional) is a list of substrings; when given, callables whose
#     name contains any of them are collected instead of names starting with "test".
# Every module must be listed once only: a repeated entry is copied and imported again.
TEST_MODULES = [
    ("execution_engine", run),
    ("pass_manager", run),
    ("dialects/affine", run_with_insertion_point_v2),
    ("dialects/func", run_with_insertion_point_v2),
    ("dialects/arith_dialect", run),
    ("dialects/arith_llvm", run),
    ("dialects/async_dialect", run),
    ("dialects/builtin", run),
    ("dialects/cf", run_with_insertion_point_v4),
    ("dialects/complex_dialect", run),
    ("dialects/index_dialect", run_with_insertion_point),
    ("dialects/llvm", run_with_insertion_point_v2),
    ("dialects/math_dialect", run),
    ("dialects/memref", run),
    ("dialects/ml_program", run_with_insertion_point_v2),
    ("dialects/nvgpu", run_with_insertion_point_v2),
    ("dialects/nvvm", run_with_insertion_point_v2),
    ("dialects/ods_helpers", run),
    ("dialects/openmp_ops", run_with_insertion_point_v2),
    ("dialects/pdl_ops", run_with_insertion_point_v2),
    # ("dialects/python_test", run),  # TODO: Need to pass pybind11 or nanobind argv
    ("dialects/quant", run),
    ("dialects/rocdl", run_with_insertion_point_v2),
    ("dialects/scf", run_with_insertion_point_v2),
    ("dialects/shape", run),
    ("dialects/spirv_dialect", run),
    ("dialects/tensor", run),
    # ("dialects/tosa", ),  # Nothing to test
    ("dialects/transform_bufferization_ext", run_with_insertion_point_v2),
    # ("dialects/transform_extras", ),  # Needs a more complicated execution schema
    ("dialects/transform_gpu_ext", run_transform_tensor_ext),
    (
        "dialects/transform_interpreter",
        run_with_context_and_location,
        ["print_", "transform_options", "failed", "include"],
    ),
    (
        "dialects/transform_loop_ext",
        run_with_insertion_point_v2,
        ["loopOutline"],
    ),
    ("dialects/transform_memref_ext", run_with_insertion_point_v2),
    ("dialects/transform_nvgpu_ext", run_with_insertion_point_v2),
    ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
    ("dialects/transform_structured_ext", run_transform_structured_ext),
    ("dialects/transform_tensor_ext", run_transform_tensor_ext),
    (
        "dialects/transform_vector_ext",
        run_apply_patterns,
        ["configurable_patterns"],
    ),
    ("dialects/transform", run_with_insertion_point_v3),
    ("dialects/vector", run_with_context_and_location),
    ("dialects/gpu/dialect", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
    ("dialects/linalg/ops", run),
    # TO ADD: No proper tests in this dialects/linalg/opsdsl/*
    # ("dialects/linalg/opsdsl/*", ...),
    ("dialects/sparse_tensor/dialect", run),
    ("dialects/sparse_tensor/passes", run),
    ("integration/dialects/pdl", run_construct_and_print_in_module),
    ("integration/dialects/transform", run_construct_and_print_in_module),
    ("integration/dialects/linalg/opsrun", run),
    ("ir/affine_expr", run),
    ("ir/affine_map", run),
    ("ir/array_attributes", run),
    ("ir/attributes", run),
    ("ir/blocks", run),
    ("ir/builtin_types", run),
    ("ir/context_managers", run),
    ("ir/debug", run),
    ("ir/diagnostic_handler", run),
    ("ir/dialects", run),
    ("ir/exception", run),
    ("ir/insertion_point", run),
    ("ir/integer_set", run),
    ("ir/location", run),
    ("ir/module", run),
    ("ir/operation", run),
    ("ir/symbol_table", run),
    ("ir/value", run),
]
280
# Generated tests that must not be created at all. `multi_threaded` strips the
# "_multi_threaded" postfix from each entry and skips every test whose generated
# name contains the remaining text. The trailing comment on each entry records
# why the test cannot be run from several threads.
TESTS_TO_SKIP = [
    "test_execution_engine__testNanoTime_multi_threaded",  # testNanoTime can't run in multiple threads, even with GIL
    "test_execution_engine__testSharedLibLoad_multi_threaded",  # testSharedLibLoad can't run in multiple threads, even with GIL
    "test_dialects_arith_dialect__testArithValue_multi_threaded",  # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
    "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded",  # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
    "test_ir_value__testValueCasters_multi_threaded",  # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
    # tests indirectly calling thread-unsafe llvm::raw_ostream
    "test_execution_engine__testInvalidModule_multi_threaded",  # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrAfterAll_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded",  # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrTree_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testRunPipeline_multi_threaded",  # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__include_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__transform_options_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__print_self_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded",  # mlirEmitError calls thread-unsafe llvm::raw_ostream
    "test_ir_module__testParseSuccess_multi_threaded",  # mlirOperationDump calls thread-unsafe llvm::raw_ostream
    # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
    # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
    "test_execution_engine__testCapsule_multi_threaded",
    "test_execution_engine__testDumpToObjectFile_multi_threaded",
]
304
# Generated tests that are still created but are expected to fail:
# `multi_threaded` wraps a test with `unittest.expectedFailure` when its full
# generated name (including the "_multi_threaded" postfix) is listed here.
TESTS_TO_XFAIL = [
    # execution_engine tests:
    # - ctypes related data-races: https://github.com/python/cpython/issues/127945
    "test_execution_engine__testBF16Memref_multi_threaded",
    "test_execution_engine__testBasicCallback_multi_threaded",
    "test_execution_engine__testComplexMemrefAdd_multi_threaded",
    "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
    "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
    "test_execution_engine__testF16MemrefAdd_multi_threaded",
    "test_execution_engine__testF8E5M2Memref_multi_threaded",
    "test_execution_engine__testInvokeFloatAdd_multi_threaded",
    "test_execution_engine__testInvokeVoid_multi_threaded",  # a ctypes race
    "test_execution_engine__testMemrefAdd_multi_threaded",
    "test_execution_engine__testRankedMemRefCallback_multi_threaded",
    "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
    # dialects tests
    "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded",  # Related to ctypes data races
    "test_dialects_transform_interpreter__print_other_multi_threaded",  # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded",  # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
    # integration tests
    "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded",  # Related to ctypes data races
    "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded",  # Related to ctypes data races
    "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded",  # ctypes
]
341
342
def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
    """Class decorator factory that attaches existing test functions to a class.

    For every entry of ``test_modules`` (a 2- or 3-tuple of module path,
    execution function and optional list of name patterns) the source file
    next to this file is copied into a temporary directory with
    ``copy_and_update``, imported with ``import_from_path``, and each selected
    callable is attached to the decorated class as
    ``{test_prefix}_{module path with "/" replaced by "_"}__{attribute name}``.

    Without patterns, attributes whose name starts with ``"test"`` are
    selected; with patterns, attributes whose name contains any pattern are.
    The temporary directory object is stored on the class as ``output_folder``.
    """

    def decorator(test_cls):
        source_root = Path(__file__).parent.absolute()
        # Stored on the class so the directory object stays referenced.
        test_cls.output_folder = tempfile.TemporaryDirectory()
        scratch_root = Path(test_cls.output_folder.name)

        for entry in test_modules:
            assert isinstance(entry, tuple) and len(entry) in (2, 3)
            test_module_name, exec_fn = entry[0], entry[1]
            test_pattern = entry[2] if len(entry) == 3 else None

            src_filepath = source_root / f"{test_module_name}.py"
            dst_filepath = (scratch_root / f"{test_module_name}.py").absolute()
            dst_dir = dst_filepath.parent
            if not dst_dir.exists():
                dst_dir.mkdir(parents=True)
            copy_and_update(src_filepath, dst_filepath)
            test_mod = import_from_path(test_module_name, dst_filepath)

            flat_module_name = test_module_name.replace("/", "_")
            for attr_name in dir(test_mod):
                if test_pattern is None:
                    selected = attr_name.startswith("test")
                else:
                    selected = any(p in attr_name for p in test_pattern)
                if not selected:
                    continue

                obj = getattr(test_mod, attr_name)
                if not callable(obj):
                    continue

                # Defaults bind the current ``obj``/``exec_fn`` to this wrapper
                # (avoids the late-binding closure pitfall inside the loop).
                def wrapped_test_fn(
                    self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
                ):
                    __exec_fn__(__test_fn__)

                generated_name = f"{test_prefix}_{flat_module_name}__{attr_name}"
                setattr(test_cls, generated_name, wrapped_test_fn)
        return test_cls

    return decorator
382
383
384@contextmanager
385def _capture_output(fp):
386    # Inspired from jax test_utils.py capture_stderr method
387    # ``None`` means nothing has not been captured yet.
388    captured = None
389
390    def get_output() -> str:
391        if captured is None:
392            raise ValueError("get_output() called while the context is active.")
393        return captured
394
395    with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
396        original_fd = os.dup(fp.fileno())
397        os.dup2(f.fileno(), fp.fileno())
398        try:
399            yield get_output
400        finally:
401            # Python also has its own buffers, make sure everything is flushed.
402            fp.flush()
403            os.fsync(fp.fileno())
404            f.seek(0)
405            captured = f.read()
406            os.dup2(original_fd, fp.fileno())
407
408
# Context-manager factories that apply `_capture_output` to the `sys.stdout` /
# `sys.stderr` objects that are current when this module is imported.
capture_stdout = partial(_capture_output, sys.stdout)
capture_stderr = partial(_capture_output, sys.stderr)
411
412
def multi_threaded(
    num_workers: int,
    num_runs: int = 5,
    skip_tests: Optional[List[str]] = None,
    xfail_tests: Optional[List[str]] = None,
    test_prefix: str = "_original_test",
    multithreaded_test_postfix: str = "_multi_threaded",
):
    """Class decorator factory that adds multi-threaded variants of tests.

    For every callable attribute of the decorated class whose name starts with
    ``test_prefix``, a new attribute ``test<rest of name><postfix>`` is added.
    Calling it runs the original function ``num_runs`` times in each of
    ``num_workers`` threads (all released together by a barrier) while stdout
    and stderr are captured, and raises ``RuntimeError`` if the captured
    stderr contains "ThreadSanitizer".

    Args:
        num_workers: number of threads running the test concurrently.
        num_runs: how many times each thread calls the test.
        skip_tests: names of tests not to generate; the postfix is stripped
            from each entry and any test whose name contains the remainder is
            skipped.
        xfail_tests: exact generated names to wrap with
            ``unittest.expectedFailure``.
        test_prefix: prefix identifying the original test attributes.
        multithreaded_test_postfix: suffix appended to generated test names.
    """
    # The skip stems do not depend on the test being examined: compute once.
    skip_stems = [
        test_name.replace(multithreaded_test_postfix, "")
        for test_name in (skip_tests or [])
    ]

    def decorator(test_cls):
        # Iterate over a copy: new attributes are added to the class in the loop.
        for name, test_fn in test_cls.__dict__.copy().items():
            if not (name.startswith(test_prefix) and callable(test_fn)):
                continue

            name = f"test{name[len(test_prefix):]}"
            if any(stem in name for stem in skip_stems):
                continue

            def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
                with capture_stdout(), capture_stderr() as get_output:
                    barrier = threading.Barrier(num_workers)

                    def closure():
                        # Start all workers at the same time.
                        barrier.wait()
                        for _ in range(num_runs):
                            __test_fn__(self, *args, **kwargs)

                    with concurrent.futures.ThreadPoolExecutor(
                        max_workers=num_workers
                    ) as executor:
                        futures = [
                            executor.submit(closure) for _ in range(num_workers)
                        ]
                        # future.result() re-raises an exception if the test has
                        # failed. It is called outside of any ``assert`` so the
                        # check still runs when assertions are disabled (-O).
                        for future in futures:
                            future.result()

                    gc.collect()
                    assert Context._get_live_count() == 0

                captured = get_output()
                if "ThreadSanitizer" in captured:
                    raise RuntimeError(
                        f"ThreadSanitizer reported warnings:\n{captured}"
                    )

            test_new_name = f"{name}{multithreaded_test_postfix}"
            if xfail_tests is not None and test_new_name in xfail_tests:
                multi_threaded_test_fn = unittest.expectedFailure(
                    multi_threaded_test_fn
                )

            setattr(test_cls, test_new_name, multi_threaded_test_fn)

        return test_cls

    return decorator
475
476
@multi_threaded(
    num_workers=10,
    num_runs=20,
    skip_tests=TESTS_TO_SKIP,
    xfail_tests=TESTS_TO_XFAIL,
)
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
class TestAllMultiThreaded(unittest.TestCase):
    # Test case holding the tests collected from TEST_MODULES plus the two
    # `_original_test_*` methods below; `multi_threaded` adds the runnable
    # `test_*_multi_threaded` variants.

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary directory created by `add_existing_tests`, if any.
        folder = getattr(cls, "output_folder", None)
        if folder is not None:
            folder.cleanup()

    def _original_test_create_context(self):
        with Context() as ctx:
            for getter_name in (
                "_get_live_count",
                "_get_live_module_count",
                "_get_live_operation_count",
                "_get_live_operation_objects",
            ):
                print(getattr(ctx, getter_name)())
            print(ctx._get_context_again() is ctx)
            print(ctx._clear_live_operations())

    def _original_test_create_module_with_consts(self):
        py_values = [123, 234, 345]
        with Context() as ctx:
            module = Module.create(loc=Location.file("foo.txt", 0, 0))

            dtype = IntegerType.get_signless(64)
            # One constant per value, each created under its own named location.
            for loc_name, value in zip(("a", "b", "c"), py_values):
                with InsertionPoint(module.body), Location.name(loc_name):
                    arith.constant(dtype, value)
513
514
if __name__ == "__main__":
    # Do not run the tests on CPython with GIL: run only when the interpreter
    # exposes `sys._is_gil_enabled` and reports that the GIL is disabled.
    gil_enabled = getattr(sys, "_is_gil_enabled", None)
    if gil_enabled is not None and not gil_enabled():
        unittest.main()
519