# RUN: %PYTHON %s
"""
This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
Tests can be run using pytest:
```bash
python3.13t -mpytest -vvv multithreaded_tests.py
```

IMPORTANT: the generated tests do not check correctness. They only execute the existing tests
in a multi-threaded context; a test passes if TSAN reports no warnings and fails otherwise.


Details on the generated tests and execution:
1) Multi-threaded execution: all generated tests are executed independently by
a pool of threads, running each test multiple times, see @multi_threaded for details

2) Tests generation: we use existing tests: test/python/ir/*.py,
test/python/dialects/*.py, etc to generate multi-threaded tests.
In detail, we perform the following:
a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
c) for each test file, we copy and modify it into a temporary directory, e.g.
test/python/ir/affine_expr.py -> <temporary directory>/ir/affine_expr.py.
In order to import the test file as python module, we remove all executing functions, like
`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.


Warnings observed from TSAN.

CPython and free-threading known data-races:
1) ctypes related races: https://github.com/python/cpython/issues/127945
2) LLVM related data-races, llvm::raw_ostream is not thread-safe
- mlir pass manager
- dialects/transform_interpreter.py
- ir/diagnostic_handler.py
- ir/module.py
3) Dialect gpu module-to-binary method is not thread-safe
"""
import concurrent.futures
import gc
import importlib.util
import os
import sys
import threading
import tempfile
import unittest

from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Optional, List

import mlir.dialects.arith as arith
from mlir.dialects import transform
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint


def import_from_path(module_name: str, file_path: Path):
    """Import the Python source file ``file_path`` as module ``module_name``.

    The module is registered in ``sys.modules`` before its code is executed,
    following the importlib "importing a source file directly" recipe.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no module spec (or no loader) can be created for
            ``file_path``.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# Lines starting (at column 0) with one of these prefixes execute a test when the
# module is imported, e.g. `run(testMethod)` or a running decorator like `@run`.
# They are dropped when a test file is copied.
_SKIPPED_LINE_PREFIXES = (
    "run(",
    "@run",
    "@constructAndPrintInModule",
    "run_apply_patterns(",
    "@run_apply_patterns",
    "@test_in_context",
    "@construct_and_print_in_module",
)


def copy_and_update(src_filepath: Path, dst_filepath: Path):
    """Copy ``src_filepath`` to ``dst_filepath`` without test-executing lines.

    Every line that starts with one of ``_SKIPPED_LINE_PREFIXES`` is removed;
    all other lines are copied verbatim. Both files are read/written as UTF-8
    (the default encoding of Python source files) instead of the locale
    encoding.
    """
    with open(src_filepath, "r", encoding="utf-8") as reader, open(
        dst_filepath, "w", encoding="utf-8"
    ) as writer:
        for src_line in reader:
            if src_line.startswith(_SKIPPED_LINE_PREFIXES):
                continue
            writer.write(src_line)


# Helper run functions: each one is used as an `exec_fn` in `TEST_MODULES` and
# receives a single test function taken from a copied test module. All of them
# return the test function, like a decorator would.
def run(f):
    """Call ``f()`` with no arguments."""
    f()
    return f


def run_with_context_and_location(f):
    """Call ``f()`` inside a fresh ``Context`` with ``Location.unknown()``."""
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        f()
    return f


def run_with_insertion_point(f):
    """Call ``f(ctx)`` with the insertion point at the body of a new module.

    The module is printed afterwards.
    """
    print("\nTEST:", f.__name__)
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f(ctx)
        print(module)
    return f


def run_with_insertion_point_v2(f):
    """Call ``f()`` with the insertion point at the body of a new module.

    The module is printed afterwards.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


def run_with_insertion_point_v3(f):
    """Call ``f(module)`` with the insertion point at the body of ``module``.

    The module is printed afterwards.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f(module)
        print(module)
    return f


def run_with_insertion_point_v4(f):
    """Call ``f()`` at the body of a new module in a context that allows
    unregistered dialects. The module is not printed."""
    print("\nTEST:", f.__name__)
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            f()
    return f


def run_apply_patterns(f):
    """Call ``f()`` inside the patterns region of a ``transform.ApplyPatternsOp``.

    The op is nested in a ``transform.SequenceOp`` (terminated by
    ``transform.YieldOp``); the module is printed afterwards.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                apply = transform.ApplyPatternsOp(sequence.bodyTarget)
                with InsertionPoint(apply.patterns):
                    f()
                transform.YieldOp()
        print("\nTEST:", f.__name__)
        print(module)
    return f


def run_transform_tensor_ext(f):
    """Call ``f(sequence.bodyTarget)`` inside the body of a ``transform.SequenceOp``.

    The sequence is terminated by ``transform.YieldOp``; the module is printed
    afterwards.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                f(sequence.bodyTarget)
                transform.YieldOp()
        print(module)
    return f


def run_transform_structured_ext(f):
    """Call ``f()`` at the body of a new module, then verify and print it."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f()
        module.operation.verify()
        print(module)
    return f


def run_construct_and_print_in_module(f):
    """Call ``f(module)`` at the body of a new module; print its result if not None."""
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            module = f(module)
        if module is not None:
            print(module)
    return f


# Each entry is (module_name, exec_fn) or (module_name, exec_fn, name_patterns).
# `module_name` is a path relative to this folder without the `.py` suffix.
TEST_MODULES = [
    ("execution_engine", run),
    ("pass_manager", run),
    ("dialects/affine", run_with_insertion_point_v2),
    ("dialects/func", run_with_insertion_point_v2),
    ("dialects/arith_dialect", run),
    ("dialects/arith_llvm", run),
    ("dialects/async_dialect", run),
    ("dialects/builtin", run),
    ("dialects/cf", run_with_insertion_point_v4),
    ("dialects/complex_dialect", run),
    ("dialects/index_dialect", run_with_insertion_point),
    ("dialects/llvm", run_with_insertion_point_v2),
    ("dialects/math_dialect", run),
    ("dialects/memref", run),
    ("dialects/ml_program", run_with_insertion_point_v2),
    ("dialects/nvgpu", run_with_insertion_point_v2),
    ("dialects/nvvm", run_with_insertion_point_v2),
    ("dialects/ods_helpers", run),
    ("dialects/openmp_ops", run_with_insertion_point_v2),
    ("dialects/pdl_ops", run_with_insertion_point_v2),
    # ("dialects/python_test", run),  # TODO: Need to pass pybind11 or nanobind argv
    ("dialects/quant", run),
    ("dialects/rocdl", run_with_insertion_point_v2),
    ("dialects/scf", run_with_insertion_point_v2),
    ("dialects/shape", run),
    ("dialects/spirv_dialect", run),
    ("dialects/tensor", run),
    # ("dialects/tosa", ),  # Nothing to test
    ("dialects/transform_bufferization_ext", run_with_insertion_point_v2),
    # ("dialects/transform_extras", ),  # Needs a more complicated execution schema
    ("dialects/transform_gpu_ext", run_transform_tensor_ext),
    (
        "dialects/transform_interpreter",
        run_with_context_and_location,
        ["print_", "transform_options", "failed", "include"],
    ),
    (
        "dialects/transform_loop_ext",
        run_with_insertion_point_v2,
        ["loopOutline"],
    ),
    ("dialects/transform_memref_ext", run_with_insertion_point_v2),
    ("dialects/transform_nvgpu_ext", run_with_insertion_point_v2),
    ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
    ("dialects/transform_structured_ext", run_transform_structured_ext),
    ("dialects/transform_tensor_ext", run_transform_tensor_ext),
    (
        "dialects/transform_vector_ext",
        run_apply_patterns,
        ["configurable_patterns"],
    ),
    ("dialects/transform", run_with_insertion_point_v3),
    ("dialects/vector", run_with_context_and_location),
    ("dialects/gpu/dialect", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
    ("dialects/linalg/ops", run),
    # TO ADD: No proper tests in this dialects/linalg/opsdsl/*
    # ("dialects/linalg/opsdsl/*", ...),
    ("dialects/sparse_tensor/dialect", run),
    ("dialects/sparse_tensor/passes", run),
    ("integration/dialects/pdl", run_construct_and_print_in_module),
    ("integration/dialects/transform", run_construct_and_print_in_module),
    ("integration/dialects/linalg/opsrun", run),
    ("ir/affine_expr", run),
    ("ir/affine_map", run),
    ("ir/array_attributes", run),
    ("ir/attributes", run),
    ("ir/blocks", run),
    ("ir/builtin_types", run),
    ("ir/context_managers", run),
    ("ir/debug", run),
    ("ir/diagnostic_handler", run),
    ("ir/dialects", run),
    ("ir/exception", run),
    ("ir/insertion_point", run),
    ("ir/integer_set", run),
    ("ir/location", run),
    ("ir/module", run),
    ("ir/operation", run),
    ("ir/symbol_table", run),
    ("ir/value", run),
]

# NOTE: `multi_threaded` matches these entries (with the "_multi_threaded"
# postfix removed) as *substrings* of the generated test names.
TESTS_TO_SKIP = [
    "test_execution_engine__testNanoTime_multi_threaded",  # testNanoTime can't run in multiple threads, even with GIL
    "test_execution_engine__testSharedLibLoad_multi_threaded",  # testSharedLibLoad can't run in multiple threads, even with GIL
    "test_dialects_arith_dialect__testArithValue_multi_threaded",  # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
    "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded",  # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
    "test_ir_value__testValueCasters_multi_threaded",  # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
    # tests indirectly calling thread-unsafe llvm::raw_ostream
    "test_execution_engine__testInvalidModule_multi_threaded",  # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrAfterAll_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded",  # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrTree_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testRunPipeline_multi_threaded",  # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__include_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__transform_options_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__print_self_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded",  # mlirEmitError calls thread-unsafe llvm::raw_ostream
    "test_ir_module__testParseSuccess_multi_threaded",  # mlirOperationDump calls thread-unsafe llvm::raw_ostream
    # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
    # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
    "test_execution_engine__testCapsule_multi_threaded",
    "test_execution_engine__testDumpToObjectFile_multi_threaded",
]

# Exact names of generated tests wrapped with `unittest.expectedFailure`.
TESTS_TO_XFAIL = [
    # execution_engine tests:
    # - ctypes related data-races: https://github.com/python/cpython/issues/127945
    "test_execution_engine__testBF16Memref_multi_threaded",
    "test_execution_engine__testBasicCallback_multi_threaded",
    "test_execution_engine__testComplexMemrefAdd_multi_threaded",
    "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
    "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
    "test_execution_engine__testF16MemrefAdd_multi_threaded",
    "test_execution_engine__testF8E5M2Memref_multi_threaded",
    "test_execution_engine__testInvokeFloatAdd_multi_threaded",
    "test_execution_engine__testInvokeVoid_multi_threaded",  # a ctypes race
    "test_execution_engine__testMemrefAdd_multi_threaded",
    "test_execution_engine__testRankedMemRefCallback_multi_threaded",
    "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
    # dialects tests
    "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded",  # Related to ctypes data races
    "test_dialects_transform_interpreter__print_other_multi_threaded",  # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded",  # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
    # integration tests
    "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded",  # Related to ctypes data races
    "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded",  # Related to ctypes data races
    "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded",  # ctypes
]


def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
    """Return a class decorator that attaches existing test functions as methods.

    Each entry of ``test_modules`` is ``(module_name, exec_fn)`` or
    ``(module_name, exec_fn, patterns)``. For every entry, the file
    ``<this folder>/<module_name>.py`` is copied with ``copy_and_update`` into a
    temporary directory (kept as ``test_cls.output_folder``) and imported.
    Each callable attribute whose name starts with ``"test"`` -- or, when
    ``patterns`` is given, contains any of the patterns -- is added to the class
    as ``{test_prefix}_{module_name with '/' replaced by '_'}__{attr_name}``.
    Calling that method runs ``exec_fn(test_fn)``.

    Raises:
        ValueError: if an entry of ``test_modules`` is malformed.
    """

    def decorator(test_cls):
        this_folder = Path(__file__).parent.absolute()
        test_cls.output_folder = tempfile.TemporaryDirectory()
        output_folder = Path(test_cls.output_folder.name)

        for test_mod_info in test_modules:
            if not (isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3)):
                raise ValueError(
                    f"Expected (module_name, exec_fn[, patterns]), got {test_mod_info!r}"
                )
            if len(test_mod_info) == 2:
                test_module_name, exec_fn = test_mod_info
                test_pattern = None
            else:
                test_module_name, exec_fn, test_pattern = test_mod_info

            src_filepath = this_folder / f"{test_module_name}.py"
            dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
            dst_filepath.parent.mkdir(parents=True, exist_ok=True)
            copy_and_update(src_filepath, dst_filepath)
            test_mod = import_from_path(test_module_name, dst_filepath)

            for attr_name in dir(test_mod):
                if test_pattern is None:
                    is_test_fn = attr_name.startswith("test")
                else:
                    is_test_fn = any(p in attr_name for p in test_pattern)
                if not is_test_fn:
                    continue
                obj = getattr(test_mod, attr_name)
                if not callable(obj):
                    continue

                test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"

                # Bind the test and exec functions as defaults so that each
                # generated method keeps its own values (no late binding).
                def wrapped_test_fn(
                    self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
                ):
                    __exec_fn__(__test_fn__)

                setattr(test_cls, test_name, wrapped_test_fn)
        return test_cls

    return decorator


@contextmanager
def _capture_output(fp):
    """Redirect the OS-level file descriptor behind ``fp`` into a temporary file.

    Yields a ``get_output()`` callable that returns everything written to the
    descriptor while the context was active; calling it before the context
    exits raises ``ValueError``. Because the descriptor itself is redirected
    with ``os.dup2``, output written directly to it by native code is captured
    too, not only writes made through the Python ``fp`` object.
    """
    # Inspired from jax test_utils.py capture_stderr method
    # ``None`` means nothing has been captured yet.
    captured = None

    def get_output() -> str:
        if captured is None:
            raise ValueError("get_output() called while the context is active.")
        return captured

    # errors="replace": undecodable bytes must not make reading the capture fail.
    with tempfile.NamedTemporaryFile(
        mode="w+", encoding="utf-8", errors="replace"
    ) as f:
        # Flush output buffered before the redirect so it reaches the original
        # destination instead of the capture file.
        fp.flush()
        original_fd = os.dup(fp.fileno())
        try:
            os.dup2(f.fileno(), fp.fileno())
            try:
                yield get_output
            finally:
                # Python also has its own buffers, make sure everything is flushed.
                fp.flush()
                os.fsync(fp.fileno())
                f.seek(0)
                captured = f.read()
        finally:
            # Always restore the original descriptor and release the duplicate
            # created by os.dup (otherwise one fd leaks per use).
            os.dup2(original_fd, fp.fileno())
            os.close(original_fd)


capture_stdout = partial(_capture_output, sys.stdout)
capture_stderr = partial(_capture_output, sys.stderr)


def multi_threaded(
    num_workers: int,
    num_runs: int = 5,
    skip_tests: Optional[List[str]] = None,
    xfail_tests: Optional[List[str]] = None,
    test_prefix: str = "_original_test",
    multithreaded_test_postfix: str = "_multi_threaded",
):
    """Decorator that runs a test in a multi-threaded environment.

    For every callable class attribute named ``<test_prefix><rest>``, a method
    ``test<rest><multithreaded_test_postfix>`` is added. It starts
    ``num_workers`` threads that wait on a common barrier and then each call the
    original method ``num_runs`` times. The method fails if any thread raises,
    if ``Context`` objects are still alive after ``gc.collect()``, or if the
    captured stderr contains "ThreadSanitizer".

    Args:
        num_workers: number of threads running the test concurrently.
        num_runs: number of calls made by each thread.
        skip_tests: generated names whose postfix-stripped form is a substring
            of the test name are not generated.
        xfail_tests: exact generated names wrapped with
            ``unittest.expectedFailure``.
        test_prefix: prefix identifying the original test methods.
        multithreaded_test_postfix: suffix appended to generated test names.
    """

    def decorator(test_cls):
        # Computed once per decoration instead of once per test method.
        skip_patterns = [
            test_name.replace(multithreaded_test_postfix, "")
            for test_name in (skip_tests or [])
        ]
        xfail_names = set(xfail_tests or [])

        for name, test_fn in test_cls.__dict__.copy().items():
            if not (name.startswith(test_prefix) and callable(test_fn)):
                continue

            name = f"test{name[len(test_prefix):]}"
            if any(pattern in name for pattern in skip_patterns):
                continue

            def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
                with capture_stdout(), capture_stderr() as get_output:
                    barrier = threading.Barrier(num_workers)

                    def closure():
                        # Start all threads together to maximize overlap.
                        barrier.wait()
                        for _ in range(num_runs):
                            __test_fn__(self, *args, **kwargs)

                    with concurrent.futures.ThreadPoolExecutor(
                        max_workers=num_workers
                    ) as executor:
                        futures = [
                            executor.submit(closure) for _ in range(num_workers)
                        ]
                        # Call future.result() to re-raise an exception if the
                        # test has failed. Not done inside an `assert`, which
                        # `python -O` would strip.
                        for future in futures:
                            future.result()

                    gc.collect()
                    self.assertEqual(Context._get_live_count(), 0)

                captured = get_output()
                if "ThreadSanitizer" in captured:
                    raise RuntimeError(
                        f"ThreadSanitizer reported warnings:\n{captured}"
                    )

            test_new_name = f"{name}{multithreaded_test_postfix}"
            if test_new_name in xfail_names:
                multi_threaded_test_fn = unittest.expectedFailure(
                    multi_threaded_test_fn
                )

            setattr(test_cls, test_new_name, multi_threaded_test_fn)

        return test_cls

    return decorator


@multi_threaded(
    num_workers=10,
    num_runs=20,
    skip_tests=TESTS_TO_SKIP,
    xfail_tests=TESTS_TO_XFAIL,
)
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
class TestAllMultiThreaded(unittest.TestCase):
    """Multi-threaded variants of the existing MLIR Python tests.

    ``add_existing_tests`` attaches the original tests as ``_original_test*``
    methods; ``multi_threaded`` then generates a ``test*_multi_threaded``
    method for each of them and for the ``_original_test*`` methods below.
    """

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary folder holding the copied test modules.
        if hasattr(cls, "output_folder"):
            cls.output_folder.cleanup()

    def _original_test_create_context(self):
        with Context() as ctx:
            print(ctx._get_live_count())
            print(ctx._get_live_module_count())
            print(ctx._get_live_operation_count())
            print(ctx._get_live_operation_objects())
            print(ctx._get_context_again() is ctx)
            print(ctx._clear_live_operations())

    def _original_test_create_module_with_consts(self):
        py_values = [123, 234, 345]
        with Context() as ctx:
            module = Module.create(loc=Location.file("foo.txt", 0, 0))

            dtype = IntegerType.get_signless(64)
            with InsertionPoint(module.body), Location.name("a"):
                arith.constant(dtype, py_values[0])

            with InsertionPoint(module.body), Location.name("b"):
                arith.constant(dtype, py_values[1])

            with InsertionPoint(module.body), Location.name("c"):
                arith.constant(dtype, py_values[2])


if __name__ == "__main__":
    # Do not run the tests on CPython with GIL
    if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled():
        unittest.main()