xref: /llvm-project/mlir/test/python/multithreaded_tests.py (revision f136c800b60dbfacdbb645e7e92acba52e2f279f)
1*f136c800Svfdev# RUN: %PYTHON %s
2*f136c800Svfdev"""
3*f136c800SvfdevThis script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
4*f136c800SvfdevTests can be run using pytest:
5*f136c800Svfdev```bash
6*f136c800Svfdevpython3.13t -mpytest -vvv multithreaded_tests.py
7*f136c800Svfdev```
8*f136c800Svfdev
IMPORTANT: the generated tests do not check correctness; they only execute the original tests in a multi-threaded context
and pass if TSAN reports no warnings, failing otherwise.
11*f136c800Svfdev
12*f136c800Svfdev
13*f136c800SvfdevDetails on the generated tests and execution:
14*f136c800Svfdev1) Multi-threaded execution: all generated tests are executed independently by
15*f136c800Svfdeva pool of threads, running each test multiple times, see @multi_threaded for details
16*f136c800Svfdev
17*f136c800Svfdev2) Tests generation: we use existing tests: test/python/ir/*.py,
18*f136c800Svfdevtest/python/dialects/*.py, etc to generate multi-threaded tests.
In detail, we perform the following:
20*f136c800Svfdeva) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
21*f136c800Svfdevb) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
22*f136c800Svfdevc) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
23*f136c800SvfdevIn order to import the test file as python module, we remove all executing functions, like
24*f136c800Svfdev`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.
25*f136c800Svfdev
26*f136c800Svfdev
27*f136c800SvfdevObserved warnings reported by TSAN.
28*f136c800Svfdev
29*f136c800SvfdevCPython and free-threading known data-races:
30*f136c800Svfdev1) ctypes related races: https://github.com/python/cpython/issues/127945
31*f136c800Svfdev2) LLVM related data-races, llvm::raw_ostream is not thread-safe
32*f136c800Svfdev- mlir pass manager
33*f136c800Svfdev- dialects/transform_interpreter.py
34*f136c800Svfdev- ir/diagnostic_handler.py
35*f136c800Svfdev- ir/module.py
3) The gpu dialect module-to-binary method is not thread-safe
37*f136c800Svfdev"""
38*f136c800Svfdevimport concurrent.futures
39*f136c800Svfdevimport gc
40*f136c800Svfdevimport importlib.util
41*f136c800Svfdevimport os
42*f136c800Svfdevimport sys
43*f136c800Svfdevimport threading
44*f136c800Svfdevimport tempfile
45*f136c800Svfdevimport unittest
46*f136c800Svfdev
47*f136c800Svfdevfrom contextlib import contextmanager
48*f136c800Svfdevfrom functools import partial
49*f136c800Svfdevfrom pathlib import Path
50*f136c800Svfdevfrom typing import Optional, List
51*f136c800Svfdev
52*f136c800Svfdevimport mlir.dialects.arith as arith
53*f136c800Svfdevfrom mlir.dialects import transform
54*f136c800Svfdevfrom mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
55*f136c800Svfdev
56*f136c800Svfdev
def import_from_path(module_name: str, file_path: Path):
    """Import the Python source file ``file_path`` as a module named ``module_name``.

    The module is registered in ``sys.modules`` before its body is executed
    (the recipe from the ``importlib`` documentation), so code running during
    the import can look it up. If executing the body raises, the partially
    initialized module is removed from ``sys.modules`` again and the exception
    propagates.

    Raises:
        ImportError: if no module spec/loader can be created for ``file_path``
            (e.g. the file suffix is not a recognized Python source suffix).
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create a module spec for {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a broken, half-initialized module behind.
        sys.modules.pop(module_name, None)
        raise
    return module
63*f136c800Svfdev
64*f136c800Svfdev
def copy_and_update(src_filepath: Path, dst_filepath: Path):
    """Copy a test file, dropping the lines that would run its tests on import.

    Every line that *starts* with one of the prefixes below is skipped. These
    are top-level ``run(...)``-style calls and module-level runner decorators
    such as ``@run``. The copy can then be imported as a plain module without
    executing any test. All other lines are copied verbatim.
    """
    # Built once; ``str.startswith`` accepts a tuple of prefixes.
    skip_prefixes = (
        "run(",
        "@run",
        "@constructAndPrintInModule",
        "run_apply_patterns(",
        "@run_apply_patterns",
        "@test_in_context",
        "@construct_and_print_in_module",
    )
    with open(src_filepath, "r", encoding="utf-8") as reader, open(
        dst_filepath, "w", encoding="utf-8"
    ) as writer:
        writer.writelines(
            line for line in reader if not line.startswith(skip_prefixes)
        )
84*f136c800Svfdev
85*f136c800Svfdev
86*f136c800Svfdev# Helper run functions
def run(f):
    """Call the test function ``f`` as is, without setting up any MLIR state."""
    test_function = f
    test_function()
89*f136c800Svfdev
90*f136c800Svfdev
def run_with_context_and_location(f):
    """Run ``f()`` with a fresh ``Context`` and ``Location.unknown()`` active.

    A ``TEST:`` banner with the function name is printed before the call, and
    ``f`` itself is returned.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            f()
    return f
96*f136c800Svfdev
97*f136c800Svfdev
def run_with_insertion_point(f):
    """Run ``f(ctx)`` with the insertion point at the body of a new module.

    A fresh ``Context`` (passed to ``f`` as its only argument) and
    ``Location.unknown()`` are active during the call. A ``TEST:`` banner is
    printed before, and the module is printed after ``f`` has populated it.
    Returns ``f``, like the other runner helpers in this file.
    """
    print("\nTEST:", f.__name__)
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f(ctx)
        print(module)
    return f
105*f136c800Svfdev
106*f136c800Svfdev
def run_with_insertion_point_v2(f):
    """Run ``f()`` with the insertion point at the body of a new module.

    A fresh ``Context`` and ``Location.unknown()`` are active during the call.
    Prints a ``TEST:`` banner first and the module afterwards; returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        new_module = Module.create()
        body_ip = InsertionPoint(new_module.body)
        with body_ip:
            f()
        print(new_module)
    return f
115*f136c800Svfdev
116*f136c800Svfdev
def run_with_insertion_point_v3(f):
    """Run ``f(module)`` with the insertion point at the body of ``module``.

    ``module`` is a new module created under a fresh ``Context`` and
    ``Location.unknown()``. The ``TEST:`` banner is printed while the insertion
    point is already active, and the module is printed after ``f`` returns.
    Returns ``f``.
    """
    with Context(), Location.unknown():
        target = Module.create()
        with InsertionPoint(target.body):
            print("\nTEST:", f.__name__)
            f(target)
        print(target)
    return f
125*f136c800Svfdev
126*f136c800Svfdev
def run_with_insertion_point_v4(f):
    """Run ``f()`` at the body of a new module, allowing unregistered dialects.

    Sets ``allow_unregistered_dialects`` on a fresh ``Context`` (with
    ``Location.unknown()`` active) before calling ``f``. Only the ``TEST:``
    banner is printed, not the module. Returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context() as context, Location.unknown():
        context.allow_unregistered_dialects = True
        scratch_module = Module.create()
        with InsertionPoint(scratch_module.body):
            f()
    return f
135*f136c800Svfdev
136*f136c800Svfdev
def run_apply_patterns(f):
    """Run ``f()`` inside the ``patterns`` region of a ``transform.ApplyPatternsOp``.

    Steps:
      1. Build a ``transform.SequenceOp`` (``Propagate`` failure mode, root
         type ``transform.AnyOpType``) in a new module.
      2. Add an ``ApplyPatternsOp`` on the sequence's body target.
      3. Call ``f()`` with the insertion point in that op's ``patterns`` region.
      4. Close the sequence with ``YieldOp``.
      5. Print the ``TEST:`` banner and the module.

    Returns ``f``.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                apply_op = transform.ApplyPatternsOp(sequence.bodyTarget)
                patterns_ip = InsertionPoint(apply_op.patterns)
                with patterns_ip:
                    f()
                transform.YieldOp()
        print("\nTEST:", f.__name__)
        print(module)
    return f
154*f136c800Svfdev
155*f136c800Svfdev
def run_transform_tensor_ext(f):
    """Run ``f(target)`` inside the body of a new ``transform.SequenceOp``.

    ``target`` is the sequence's ``bodyTarget``. The sequence uses the
    ``Propagate`` failure mode and root type ``transform.AnyOpType``. After
    ``f`` returns, a ``YieldOp`` terminates the body and the module is printed.
    Returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                body_target = sequence.bodyTarget
                f(body_target)
                transform.YieldOp()
        print(module)
    return f
171*f136c800Svfdev
172*f136c800Svfdev
def run_transform_structured_ext(f):
    """Run ``f()`` at the body of a new module, then verify and print the module.

    The ``TEST:`` banner is printed while the insertion point is active.
    Returns ``f``.
    """
    with Context(), Location.unknown():
        mod = Module.create()
        with InsertionPoint(mod.body):
            print("\nTEST:", f.__name__)
            f()
        # Verify the constructed IR before printing it.
        top_level_op = mod.operation
        top_level_op.verify()
        print(mod)
    return f
182*f136c800Svfdev
183*f136c800Svfdev
def run_construct_and_print_in_module(f):
    """Run ``module = f(module)`` at the body of a new module and print the result.

    ``f`` receives the new module and its return value replaces it. When ``f``
    returns ``None`` nothing is printed. Returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            module = f(module)
        if module is None:
            return f
        print(module)
    return f
193*f136c800Svfdev
194*f136c800Svfdev
# Source test files to turn into multi-threaded tests. Each entry is
# ``(path, runner)`` or ``(path, runner, name_patterns)``:
#   - path: test file relative to this folder, without the ``.py`` suffix;
#   - runner: helper that executes one test function (one of the ``run*``
#     helpers defined in this file);
#   - name_patterns: optional substrings. When given, only attributes whose
#     names contain one of them are used as tests, instead of ``test*``.
# List each path only once: a duplicate entry re-copies and re-imports the
# same file and then overwrites the tests generated from the first entry.
TEST_MODULES = [
    ("execution_engine", run),
    ("pass_manager", run),
    ("dialects/affine", run_with_insertion_point_v2),
    ("dialects/func", run_with_insertion_point_v2),
    ("dialects/arith_dialect", run),
    ("dialects/arith_llvm", run),
    ("dialects/async_dialect", run),
    ("dialects/builtin", run),
    ("dialects/cf", run_with_insertion_point_v4),
    ("dialects/complex_dialect", run),
    ("dialects/index_dialect", run_with_insertion_point),
    ("dialects/llvm", run_with_insertion_point_v2),
    ("dialects/math_dialect", run),
    ("dialects/memref", run),
    ("dialects/ml_program", run_with_insertion_point_v2),
    ("dialects/nvgpu", run_with_insertion_point_v2),
    ("dialects/nvvm", run_with_insertion_point_v2),
    ("dialects/ods_helpers", run),
    ("dialects/openmp_ops", run_with_insertion_point_v2),
    ("dialects/pdl_ops", run_with_insertion_point_v2),
    # ("dialects/python_test", run),  # TODO: Need to pass pybind11 or nanobind argv
    ("dialects/quant", run),
    ("dialects/rocdl", run_with_insertion_point_v2),
    ("dialects/scf", run_with_insertion_point_v2),
    ("dialects/shape", run),
    ("dialects/spirv_dialect", run),
    ("dialects/tensor", run),
    # ("dialects/tosa", ),  # Nothing to test
    ("dialects/transform_bufferization_ext", run_with_insertion_point_v2),
    # ("dialects/transform_extras", ),  # Needs a more complicated execution schema
    ("dialects/transform_gpu_ext", run_transform_tensor_ext),
    (
        "dialects/transform_interpreter",
        run_with_context_and_location,
        ["print_", "transform_options", "failed", "include"],
    ),
    (
        "dialects/transform_loop_ext",
        run_with_insertion_point_v2,
        ["loopOutline"],
    ),
    ("dialects/transform_memref_ext", run_with_insertion_point_v2),
    ("dialects/transform_nvgpu_ext", run_with_insertion_point_v2),
    ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
    ("dialects/transform_structured_ext", run_transform_structured_ext),
    ("dialects/transform_tensor_ext", run_transform_tensor_ext),
    (
        "dialects/transform_vector_ext",
        run_apply_patterns,
        ["configurable_patterns"],
    ),
    ("dialects/transform", run_with_insertion_point_v3),
    ("dialects/vector", run_with_context_and_location),
    ("dialects/gpu/dialect", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
    ("dialects/linalg/ops", run),
    # TO ADD: No proper tests in this dialects/linalg/opsdsl/*
    # ("dialects/linalg/opsdsl/*", ...),
    ("dialects/sparse_tensor/dialect", run),
    ("dialects/sparse_tensor/passes", run),
    ("integration/dialects/pdl", run_construct_and_print_in_module),
    ("integration/dialects/transform", run_construct_and_print_in_module),
    ("integration/dialects/linalg/opsrun", run),
    ("ir/affine_expr", run),
    ("ir/affine_map", run),
    ("ir/array_attributes", run),
    ("ir/attributes", run),
    ("ir/blocks", run),
    ("ir/builtin_types", run),
    ("ir/context_managers", run),
    ("ir/debug", run),
    ("ir/diagnostic_handler", run),
    ("ir/dialects", run),
    ("ir/exception", run),
    ("ir/insertion_point", run),
    ("ir/integer_set", run),
    ("ir/location", run),
    ("ir/module", run),
    ("ir/operation", run),
    ("ir/symbol_table", run),
    ("ir/value", run),
]
280*f136c800Svfdev
# Generated multi-threaded tests that are not created at all. ``multi_threaded``
# drops a test when one of these names, with the ``_multi_threaded`` postfix
# removed, is a substring of the test's base name.
TESTS_TO_SKIP = (
    [
        # Tests that cannot run in multiple threads at all, even with the GIL:
        "test_execution_engine__testNanoTime_multi_threaded",  # testNanoTime can't run in multiple threads, even with GIL
        "test_execution_engine__testSharedLibLoad_multi_threaded",  # testSharedLibLoad can't run in multiple threads, even with GIL
        "test_dialects_arith_dialect__testArithValue_multi_threaded",  # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
        "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded",  # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
        "test_ir_value__testValueCasters_multi_threaded",  # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
    ]
    + [
        # Tests indirectly calling thread-unsafe llvm::raw_ostream:
        "test_execution_engine__testInvalidModule_multi_threaded",  # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream
        "test_pass_manager__testPrintIrAfterAll_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
        "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded",  # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream
        "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
        "test_pass_manager__testPrintIrTree_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
        "test_pass_manager__testRunPipeline_multi_threaded",  # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream
        "test_dialects_transform_interpreter__include_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
        "test_dialects_transform_interpreter__transform_options_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
        "test_dialects_transform_interpreter__print_self_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) call thread-unsafe llvm::raw_ostream
        "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded",  # mlirEmitError calls thread-unsafe llvm::raw_ostream
        "test_ir_module__testParseSuccess_multi_threaded",  # mlirOperationDump calls thread-unsafe llvm::raw_ostream
    ]
    + [
        # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
        # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
        "test_execution_engine__testCapsule_multi_threaded",
        "test_execution_engine__testDumpToObjectFile_multi_threaded",
    ]
)
304*f136c800Svfdev
# Generated multi-threaded tests (exact names) that are still created but
# wrapped with ``unittest.expectedFailure`` by ``multi_threaded``.
TESTS_TO_XFAIL = (
    [
        # execution_engine tests:
        # - ctypes related data-races: https://github.com/python/cpython/issues/127945
        "test_execution_engine__testBF16Memref_multi_threaded",
        "test_execution_engine__testBasicCallback_multi_threaded",
        "test_execution_engine__testComplexMemrefAdd_multi_threaded",
        "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
        "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
        "test_execution_engine__testF16MemrefAdd_multi_threaded",
        "test_execution_engine__testF8E5M2Memref_multi_threaded",
        "test_execution_engine__testInvokeFloatAdd_multi_threaded",
        "test_execution_engine__testInvokeVoid_multi_threaded",  # a ctypes race
        "test_execution_engine__testMemrefAdd_multi_threaded",
        "test_execution_engine__testRankedMemRefCallback_multi_threaded",
        "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
        "test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
        "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
    ]
    + [
        # dialects tests
        "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded",  # Related to ctypes data races
        "test_dialects_transform_interpreter__print_other_multi_threaded",  # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
        "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded",  # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
        "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
        "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
        "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
    ]
    + [
        # integration tests, all related to ctypes data races
        "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded",
        "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded",
    ]
)
341*f136c800Svfdev
342*f136c800Svfdev
def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
    """Class decorator that attaches the tests of existing test files as methods.

    Each ``test_modules`` entry is a tuple ``(module_name, exec_fn)`` or
    ``(module_name, exec_fn, name_patterns)``. For each entry:
      - ``<this folder>/<module_name>.py`` is copied, without its top-level
        runner calls (see ``copy_and_update``), into a temporary directory
        stored on the class as ``output_folder``;
      - the copy is imported;
      - every callable attribute whose name starts with ``"test"`` is selected.
        When ``name_patterns`` is given, attributes whose name contains any of
        those substrings are selected instead;
      - each selected attribute is added to the class as
        ``<test_prefix>_<module_name with '/' -> '_'>__<attr_name>``. Calling
        that method runs ``exec_fn(test_function)``.

    Raises:
        ValueError: if an entry of ``test_modules`` is not a 2- or 3-tuple.
    """

    def decorator(test_cls):
        this_folder = Path(__file__).parent.absolute()
        test_cls.output_folder = tempfile.TemporaryDirectory()
        output_folder = Path(test_cls.output_folder.name)

        for test_mod_info in test_modules:
            if not (isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3)):
                raise ValueError(
                    "Expected (module_name, exec_fn[, name_patterns]), "
                    f"got {test_mod_info!r}"
                )
            test_module_name, exec_fn, *rest = test_mod_info
            test_pattern = rest[0] if rest else None

            src_filepath = this_folder / f"{test_module_name}.py"
            dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
            dst_filepath.parent.mkdir(parents=True, exist_ok=True)
            copy_and_update(src_filepath, dst_filepath)
            test_mod = import_from_path(test_module_name, dst_filepath)
            mangled_module_name = test_module_name.replace("/", "_")

            for attr_name in dir(test_mod):
                if test_pattern is None:
                    is_test_fn = attr_name.startswith("test")
                else:
                    is_test_fn = any(p in attr_name for p in test_pattern)
                if not is_test_fn:
                    continue
                obj = getattr(test_mod, attr_name)
                if not callable(obj):
                    continue
                test_name = f"{test_prefix}_{mangled_module_name}__{attr_name}"

                # ``obj`` and ``exec_fn`` are bound as defaults so each method
                # keeps the values from this iteration (no late binding).
                def wrapped_test_fn(
                    self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
                ):
                    __exec_fn__(__test_fn__)

                setattr(test_cls, test_name, wrapped_test_fn)
        return test_cls

    return decorator
382*f136c800Svfdev
383*f136c800Svfdev
384*f136c800Svfdev@contextmanager
385*f136c800Svfdevdef _capture_output(fp):
386*f136c800Svfdev    # Inspired from jax test_utils.py capture_stderr method
387*f136c800Svfdev    # ``None`` means nothing has not been captured yet.
388*f136c800Svfdev    captured = None
389*f136c800Svfdev
390*f136c800Svfdev    def get_output() -> str:
391*f136c800Svfdev        if captured is None:
392*f136c800Svfdev            raise ValueError("get_output() called while the context is active.")
393*f136c800Svfdev        return captured
394*f136c800Svfdev
395*f136c800Svfdev    with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
396*f136c800Svfdev        original_fd = os.dup(fp.fileno())
397*f136c800Svfdev        os.dup2(f.fileno(), fp.fileno())
398*f136c800Svfdev        try:
399*f136c800Svfdev            yield get_output
400*f136c800Svfdev        finally:
401*f136c800Svfdev            # Python also has its own buffers, make sure everything is flushed.
402*f136c800Svfdev            fp.flush()
403*f136c800Svfdev            os.fsync(fp.fileno())
404*f136c800Svfdev            f.seek(0)
405*f136c800Svfdev            captured = f.read()
406*f136c800Svfdev            os.dup2(original_fd, fp.fileno())
407*f136c800Svfdev
408*f136c800Svfdev
# Stream-specific shortcuts for ``_capture_output``, e.g.
# ``with capture_stderr() as get_output: ...``. The ``sys.stdout`` /
# ``sys.stderr`` objects are bound once, when this line runs at import time.
capture_stdout = partial(_capture_output, fp=sys.stdout)
capture_stderr = partial(_capture_output, fp=sys.stderr)
411*f136c800Svfdev
412*f136c800Svfdev
def multi_threaded(
    num_workers: int,
    num_runs: int = 5,
    skip_tests: Optional[List[str]] = None,
    xfail_tests: Optional[List[str]] = None,
    test_prefix: str = "_original_test",
    multithreaded_test_postfix: str = "_multi_threaded",
):
    """Class decorator that adds a multi-threaded variant of each prefixed test.

    For every callable class attribute ``<test_prefix><suffix>``, a method
    ``test<suffix><multithreaded_test_postfix>`` is added. That method:
      - starts ``num_workers`` threads that wait on a shared barrier and then
        each call the original test ``num_runs`` times;
      - redirects stdout and stderr to temporary files while the threads run;
      - fails if any worker raised, if MLIR contexts are still alive after a
        garbage collection, or if the captured stderr mentions
        ``ThreadSanitizer``.

    Args:
        num_workers: number of concurrent threads.
        num_runs: how many times each thread calls the test.
        skip_tests: generated test names for which no method is created. A
            test is skipped when any entry, with the postfix removed, is a
            substring of its base name.
        xfail_tests: generated test names (exact match) that are wrapped with
            ``unittest.expectedFailure``.
        test_prefix: prefix identifying the original tests on the class.
        multithreaded_test_postfix: suffix appended to generated test names.
    """

    def decorator(test_cls):
        for attr_name, test_fn in test_cls.__dict__.copy().items():
            if not (attr_name.startswith(test_prefix) and callable(test_fn)):
                continue

            base_name = f"test{attr_name[len(test_prefix):]}"
            if skip_tests is not None and any(
                test_name.replace(multithreaded_test_postfix, "") in base_name
                for test_name in skip_tests
            ):
                continue

            def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
                with capture_stdout(), capture_stderr() as get_output:
                    barrier = threading.Barrier(num_workers)

                    def closure():
                        # Release all workers at once to maximize overlap.
                        barrier.wait()
                        for _ in range(num_runs):
                            __test_fn__(self, *args, **kwargs)

                    with concurrent.futures.ThreadPoolExecutor(
                        max_workers=num_workers
                    ) as executor:
                        futures = [executor.submit(closure) for _ in range(num_workers)]
                        # result() re-raises a worker's exception. It must be
                        # a plain statement: inside an ``assert`` it would be
                        # stripped under ``python -O`` and failures ignored.
                        for future in futures:
                            future.result()

                    gc.collect()
                    live_contexts = Context._get_live_count()
                    if live_contexts != 0:
                        raise AssertionError(
                            f"{live_contexts} MLIR context(s) still alive after the test"
                        )

                captured = get_output()
                if "ThreadSanitizer" in captured:
                    raise RuntimeError(
                        f"ThreadSanitizer reported warnings:\n{captured}"
                    )

            test_new_name = f"{base_name}{multithreaded_test_postfix}"
            if xfail_tests is not None and test_new_name in xfail_tests:
                multi_threaded_test_fn = unittest.expectedFailure(
                    multi_threaded_test_fn
                )

            setattr(test_cls, test_new_name, multi_threaded_test_fn)

        return test_cls

    return decorator
475*f136c800Svfdev
476*f136c800Svfdev
@multi_threaded(
    num_workers=10,
    num_runs=20,
    skip_tests=TESTS_TO_SKIP,
    xfail_tests=TESTS_TO_XFAIL,
)
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
class TestAllMultiThreaded(unittest.TestCase):
    """Multi-threaded variants of the MLIR Python tests.

    ``add_existing_tests`` attaches the tests from ``TEST_MODULES`` as
    ``_original_test*`` methods. ``multi_threaded`` then adds a
    ``test*_multi_threaded`` method for each of them and for the
    ``_original_test*`` methods defined below.
    """

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary folder holding the copied test modules.
        if hasattr(cls, "output_folder"):
            cls.output_folder.cleanup()
        super().tearDownClass()

    def _original_test_create_context(self):
        with Context() as ctx:
            print(ctx._get_live_count())
            print(ctx._get_live_module_count())
            print(ctx._get_live_operation_count())
            print(ctx._get_live_operation_objects())
            print(ctx._get_context_again() is ctx)
            print(ctx._clear_live_operations())

    def _original_test_create_module_with_consts(self):
        py_values = [123, 234, 345]
        with Context():
            module = Module.create(loc=Location.file("foo.txt", 0, 0))

            dtype = IntegerType.get_signless(64)
            with InsertionPoint(module.body), Location.name("a"):
                arith.constant(dtype, py_values[0])

            with InsertionPoint(module.body), Location.name("b"):
                arith.constant(dtype, py_values[1])

            with InsertionPoint(module.body), Location.name("c"):
                arith.constant(dtype, py_values[2])
513*f136c800Svfdev
514*f136c800Svfdev
if __name__ == "__main__":
    # Run the tests only on a free-threaded CPython build with the GIL disabled.
    # ``sys._is_gil_enabled`` does not exist before Python 3.13.
    is_gil_enabled = getattr(sys, "_is_gil_enabled", None)
    if is_gil_enabled is not None and not is_gil_enabled():
        unittest.main()
519