# RUN: %PYTHON %s
"""
This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
Tests can be run using pytest:
```bash
python3.13t -mpytest -vvv multithreaded_tests.py
```

IMPORTANT. The generated tests do not check correctness: they only execute the
existing tests in a multi-threaded context. A test passes if TSAN reports no
warnings and fails otherwise.


Details on the generated tests and execution:
1) Multi-threaded execution: all generated tests are executed independently by
a pool of threads, running each test multiple times, see @multi_threaded for details

2) Tests generation: we use existing tests: test/python/ir/*.py,
test/python/dialects/*.py, etc to generate multi-threaded tests.
In detail, we perform the following:
a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
In order to import the test file as python module, we remove all executing functions, like
`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.


Observed warnings reported by TSAN.

CPython and free-threading known data-races:
1) ctypes related races: https://github.com/python/cpython/issues/127945
2) LLVM related data-races, llvm::raw_ostream is not thread-safe
- mlir pass manager
- dialects/transform_interpreter.py
- ir/diagnostic_handler.py
- ir/module.py
3) Dialect gpu module-to-binary method is unsafe
"""
import concurrent.futures
import gc
import importlib.util
import os
import sys
import threading
import tempfile
import unittest

from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Optional, List

import mlir.dialects.arith as arith
from mlir.dialects import transform
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint


def import_from_path(module_name: str, file_path: Path):
    """Import the Python file at `file_path` as a module named `module_name`.

    The module is registered in `sys.modules` before it is executed and is
    returned to the caller.

    Raises:
        ImportError: if no import spec (or loader) can be created for the file.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for {module_name!r} from {file_path}"
        )
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# Lines of a source test file starting with one of these prefixes execute a
# test at import time (directly or through a decorator) and are therefore
# dropped by `copy_and_update`.
_SKIPPED_LINE_PREFIXES = (
    "run(",
    "@run",
    "@constructAndPrintInModule",
    "run_apply_patterns(",
    "@run_apply_patterns",
    "@test_in_context",
    "@construct_and_print_in_module",
)


def copy_and_update(src_filepath: Path, dst_filepath: Path):
    """Copy `src_filepath` to `dst_filepath` without the test-running lines.

    We should remove all calls like `run(testMethod)` and decorators like
    `@run`, so that importing the copy only defines the test functions
    instead of executing them. Only lines that start (at column 0) with one
    of `_SKIPPED_LINE_PREFIXES` are removed; everything else is copied as is.
    """
    with open(src_filepath, "r", encoding="utf-8") as reader, open(
        dst_filepath, "w", encoding="utf-8"
    ) as writer:
        for src_line in reader:
            if src_line.startswith(_SKIPPED_LINE_PREFIXES):
                continue
            writer.write(src_line)


# Helper run functions
def run(f):
    """Call `f` without any additional setup."""
    f()


def run_with_context_and_location(f):
    """Call `f()` inside a fresh `Context` and an unknown `Location`."""
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        f()
    return f


def run_with_insertion_point(f):
    """Call `f(ctx)` with the insertion point set to a new module body.

    The module is printed after `f` returns.
    """
    print("\nTEST:", f.__name__)
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f(ctx)
        print(module)


def run_with_insertion_point_v2(f):
    """Call `f()` with the insertion point set to a new module body.

    The module is printed after `f` returns.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


def run_with_insertion_point_v3(f):
    """Call `f(module)` with the insertion point set to the new module body.

    The module is printed after `f` returns.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f(module)
        print(module)
    return f


def run_with_insertion_point_v4(f):
    """Call `f()` in a new module body with unregistered dialects allowed.

    Unlike the other variants, the module is not printed.
    """
    print("\nTEST:", f.__name__)
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            f()
    return f


def run_apply_patterns(f):
    """Call `f()` with the insertion point inside a `transform.ApplyPatternsOp`.

    The apply-patterns op is built inside a `transform.SequenceOp`, which is
    terminated with a `transform.YieldOp`; the enclosing module is printed.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                apply = transform.ApplyPatternsOp(sequence.bodyTarget)
                with InsertionPoint(apply.patterns):
                    f()
                transform.YieldOp()
        print("\nTEST:", f.__name__)
        print(module)
    return f


def run_transform_tensor_ext(f):
    """Call `f(sequence.bodyTarget)` inside the body of a `transform.SequenceOp`.

    The sequence is terminated with a `transform.YieldOp` and the enclosing
    module is printed.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                f(sequence.bodyTarget)
                transform.YieldOp()
        print(module)
    return f


def run_transform_structured_ext(f):
    """Call `f()` in a new module body, then verify and print the module."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f()
        module.operation.verify()
        print(module)
    return f


def run_construct_and_print_in_module(f):
    """Call `f(module)` in a new module body and print what `f` returns.

    Nothing is printed when `f` returns `None`.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            module = f(module)
        if module is not None:
            print(module)
    return f


# Source test modules used to generate the multi-threaded tests. Each entry is
# `(module path relative to this folder, runner)` or
# `(module path, runner, name patterns)`. Without patterns, every callable
# whose name starts with "test" is used; with patterns, every callable whose
# name contains one of the patterns is used (see `add_existing_tests`).
TEST_MODULES = [
    ("execution_engine", run),
    ("pass_manager", run),
    ("dialects/affine", run_with_insertion_point_v2),
    ("dialects/arith_dialect", run),
    ("dialects/arith_llvm", run),
    ("dialects/async_dialect", run),
    ("dialects/builtin", run),
    ("dialects/cf", run_with_insertion_point_v4),
    ("dialects/complex_dialect", run),
    ("dialects/func", run_with_insertion_point_v2),
    ("dialects/index_dialect", run_with_insertion_point),
    ("dialects/llvm", run_with_insertion_point_v2),
    ("dialects/math_dialect", run),
    ("dialects/memref", run),
    ("dialects/ml_program", run_with_insertion_point_v2),
    ("dialects/nvgpu", run_with_insertion_point_v2),
    ("dialects/nvvm", run_with_insertion_point_v2),
    ("dialects/ods_helpers", run),
    ("dialects/openmp_ops", run_with_insertion_point_v2),
    ("dialects/pdl_ops", run_with_insertion_point_v2),
    # ("dialects/python_test", run),  # TODO: Need to pass pybind11 or nanobind argv
    ("dialects/quant", run),
    ("dialects/rocdl", run_with_insertion_point_v2),
    ("dialects/scf", run_with_insertion_point_v2),
    ("dialects/shape", run),
    ("dialects/spirv_dialect", run),
    ("dialects/tensor", run),
    # ("dialects/tosa", ),  # Nothing to test
    ("dialects/transform_bufferization_ext", run_with_insertion_point_v2),
    # ("dialects/transform_extras", ),  # Needs a more complicated execution schema
    ("dialects/transform_gpu_ext", run_transform_tensor_ext),
    (
        "dialects/transform_interpreter",
        run_with_context_and_location,
        ["print_", "transform_options", "failed", "include"],
    ),
    (
        "dialects/transform_loop_ext",
        run_with_insertion_point_v2,
        ["loopOutline"],
    ),
    ("dialects/transform_memref_ext", run_with_insertion_point_v2),
    ("dialects/transform_nvgpu_ext", run_with_insertion_point_v2),
    ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
    ("dialects/transform_structured_ext", run_transform_structured_ext),
    ("dialects/transform_tensor_ext", run_transform_tensor_ext),
    (
        "dialects/transform_vector_ext",
        run_apply_patterns,
        ["configurable_patterns"],
    ),
    ("dialects/transform", run_with_insertion_point_v3),
    ("dialects/vector", run_with_context_and_location),
    ("dialects/gpu/dialect", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
    ("dialects/linalg/ops", run),
    # TO ADD: No proper tests in this dialects/linalg/opsdsl/*
    # ("dialects/linalg/opsdsl/*", ...),
    ("dialects/sparse_tensor/dialect", run),
    ("dialects/sparse_tensor/passes", run),
    ("integration/dialects/pdl", run_construct_and_print_in_module),
    ("integration/dialects/transform", run_construct_and_print_in_module),
    ("integration/dialects/linalg/opsrun", run),
    ("ir/affine_expr", run),
    ("ir/affine_map", run),
    ("ir/array_attributes", run),
    ("ir/attributes", run),
    ("ir/blocks", run),
    ("ir/builtin_types", run),
    ("ir/context_managers", run),
    ("ir/debug", run),
    ("ir/diagnostic_handler", run),
    ("ir/dialects", run),
    ("ir/exception", run),
    ("ir/insertion_point", run),
    ("ir/integer_set", run),
    ("ir/location", run),
    ("ir/module", run),
    ("ir/operation", run),
    ("ir/symbol_table", run),
    ("ir/value", run),
]

# Generated tests that are not created at all (see `multi_threaded`).
TESTS_TO_SKIP = [
    "test_execution_engine__testNanoTime_multi_threaded",  # testNanoTime can't run in multiple threads, even with GIL
    "test_execution_engine__testSharedLibLoad_multi_threaded",  # testSharedLibLoad can't run in multiple threads, even with GIL
    "test_dialects_arith_dialect__testArithValue_multi_threaded",  # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
    "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded",  # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
    "test_ir_value__testValueCasters_multi_threaded",  # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
    # tests indirectly calling thread-unsafe llvm::raw_ostream
    "test_execution_engine__testInvalidModule_multi_threaded",  # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrAfterAll_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded",  # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testPrintIrTree_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
    "test_pass_manager__testRunPipeline_multi_threaded",  # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__include_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__transform_options_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_dialects_transform_interpreter__print_self_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
    "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded",  # mlirEmitError calls thread-unsafe llvm::raw_ostream
    "test_ir_module__testParseSuccess_multi_threaded",  # mlirOperationDump calls thread-unsafe llvm::raw_ostream
    # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
    # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
    "test_execution_engine__testCapsule_multi_threaded",
    "test_execution_engine__testDumpToObjectFile_multi_threaded",
]

# Generated tests that are created but marked with `unittest.expectedFailure`.
TESTS_TO_XFAIL = [
    # execution_engine tests:
    # - ctypes related data-races: https://github.com/python/cpython/issues/127945
    "test_execution_engine__testBF16Memref_multi_threaded",
    "test_execution_engine__testBasicCallback_multi_threaded",
    "test_execution_engine__testComplexMemrefAdd_multi_threaded",
    "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
    "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
    "test_execution_engine__testF16MemrefAdd_multi_threaded",
    "test_execution_engine__testF8E5M2Memref_multi_threaded",
    "test_execution_engine__testInvokeFloatAdd_multi_threaded",
    "test_execution_engine__testInvokeVoid_multi_threaded",  # a ctypes race
    "test_execution_engine__testMemrefAdd_multi_threaded",
    "test_execution_engine__testRankedMemRefCallback_multi_threaded",
    "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
    # dialects tests
    "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded",  # Related to ctypes data races
    "test_dialects_transform_interpreter__print_other_multi_threaded",  # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded",  # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
    # integration tests
    "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded",  # Related to ctypes data races
    "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded",  # Related to ctypes data races
    "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded",  # ctypes
    "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded",  # ctypes
]


def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
    """Class decorator attaching the tests found in `test_modules` to a class.

    Every entry of `test_modules` is a `(module_name, exec_fn)` or
    `(module_name, exec_fn, patterns)` tuple. The source file
    `<this folder>/<module_name>.py` is copied (through `copy_and_update`)
    into a temporary directory stored as `test_cls.output_folder`, imported
    from there, and each selected callable `fn` is exposed as the method
    `<test_prefix>_<module_name with '/' replaced by '_'>__<fn name>`, which
    calls `exec_fn(fn)`.

    A callable is selected when its name starts with "test" (no patterns
    given) or contains at least one of the patterns.
    """

    def decorator(test_cls):
        this_folder = Path(__file__).parent.absolute()
        test_cls.output_folder = tempfile.TemporaryDirectory()
        output_folder = Path(test_cls.output_folder.name)

        for test_mod_info in test_modules:
            assert isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3)
            if len(test_mod_info) == 2:
                test_module_name, exec_fn = test_mod_info
                test_pattern = None
            else:
                test_module_name, exec_fn, test_pattern = test_mod_info

            src_filepath = this_folder / f"{test_module_name}.py"
            dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
            dst_filepath.parent.mkdir(parents=True, exist_ok=True)
            copy_and_update(src_filepath, dst_filepath)
            test_mod = import_from_path(test_module_name, dst_filepath)

            flat_module_name = test_module_name.replace("/", "_")
            for attr_name in dir(test_mod):
                if test_pattern is None:
                    is_test_fn = attr_name.startswith("test")
                else:
                    is_test_fn = any(p in attr_name for p in test_pattern)
                if not is_test_fn:
                    continue

                obj = getattr(test_mod, attr_name)
                if not callable(obj):
                    continue

                test_name = f"{test_prefix}_{flat_module_name}__{attr_name}"

                # The current `obj` and `exec_fn` are bound as default
                # arguments so that each wrapper keeps its own pair instead of
                # the values of the last loop iteration.
                def wrapped_test_fn(
                    self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
                ):
                    __exec_fn__(__test_fn__)

                setattr(test_cls, test_name, wrapped_test_fn)
        return test_cls

    return decorator


@contextmanager
def _capture_output(fp):
    """Redirect the file descriptor behind `fp` into a temporary file.

    Yields a `get_output` callable returning the captured text. It can only
    be called after the context has exited and raises `ValueError` before.
    """
    # Inspired from jax test_utils.py capture_stderr method
    # ``None`` means nothing has been captured yet.
    captured = None

    def get_output() -> str:
        if captured is None:
            raise ValueError("get_output() called while the context is active.")
        return captured

    with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
        original_fd = os.dup(fp.fileno())
        os.dup2(f.fileno(), fp.fileno())
        try:
            yield get_output
        finally:
            try:
                # Python also has its own buffers, make sure everything is flushed.
                fp.flush()
                os.fsync(fp.fileno())
                f.seek(0)
                captured = f.read()
            finally:
                # Always restore the original descriptor, and close the
                # duplicate: otherwise every capture leaks one descriptor.
                os.dup2(original_fd, fp.fileno())
                os.close(original_fd)


capture_stdout = partial(_capture_output, sys.stdout)
capture_stderr = partial(_capture_output, sys.stderr)


def multi_threaded(
    num_workers: int,
    num_runs: int = 5,
    skip_tests: Optional[List[str]] = None,
    xfail_tests: Optional[List[str]] = None,
    test_prefix: str = "_original_test",
    multithreaded_test_postfix: str = "_multi_threaded",
):
    """Decorator that runs a test in a multi-threaded environment.

    For every callable attribute of the decorated class whose name starts with
    `test_prefix`, a method `test<rest of the name><multithreaded_test_postfix>`
    is added. It starts `num_workers` threads, synchronized on a barrier, each
    calling the original function `num_runs` times, and raises `RuntimeError`
    if the captured stderr mentions "ThreadSanitizer".

    Args:
        num_workers: number of threads running the test concurrently.
        num_runs: number of times each thread calls the test.
        skip_tests: generated test names to leave out. NOTE: this is a
            substring match, an entry (with the postfix removed) skips every
            generated test whose name contains it.
        xfail_tests: generated test names to wrap in `unittest.expectedFailure`.
        test_prefix: name prefix identifying the original tests.
        multithreaded_test_postfix: suffix appended to the generated names.
    """
    # Hoisted out of the per-test loop: these depend only on the arguments.
    skip_patterns = [
        test_name.replace(multithreaded_test_postfix, "")
        for test_name in (skip_tests or ())
    ]
    xfail_names = frozenset(xfail_tests or ())

    def decorator(test_cls):
        for name, test_fn in test_cls.__dict__.copy().items():
            if not (name.startswith(test_prefix) and callable(test_fn)):
                continue

            name = f"test{name[len(test_prefix):]}"
            if any(pattern in name for pattern in skip_patterns):
                continue

            def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
                # stdout is redirected only to discard it; the TSAN check
                # below reads the captured stderr.
                with capture_stdout(), capture_stderr() as get_output:
                    barrier = threading.Barrier(num_workers)

                    def closure():
                        # Make all the workers start the test at the same time.
                        barrier.wait()
                        for _ in range(num_runs):
                            __test_fn__(self, *args, **kwargs)

                    with concurrent.futures.ThreadPoolExecutor(
                        max_workers=num_workers
                    ) as executor:
                        futures = [
                            executor.submit(closure) for _ in range(num_workers)
                        ]
                        # We should call future.result() to re-raise an exception if test has
                        # failed
                        for future in futures:
                            future.result()

                    gc.collect()
                    assert Context._get_live_count() == 0

                captured = get_output()
                if len(captured) > 0 and "ThreadSanitizer" in captured:
                    raise RuntimeError(
                        f"ThreadSanitizer reported warnings:\n{captured}"
                    )

            test_new_name = f"{name}{multithreaded_test_postfix}"
            if test_new_name in xfail_names:
                multi_threaded_test_fn = unittest.expectedFailure(
                    multi_threaded_test_fn
                )

            setattr(test_cls, test_new_name, multi_threaded_test_fn)

        return test_cls

    return decorator


@multi_threaded(
    num_workers=10,
    num_runs=20,
    skip_tests=TESTS_TO_SKIP,
    xfail_tests=TESTS_TO_XFAIL,
)
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
class TestAllMultiThreaded(unittest.TestCase):
    """Test case holding the generated multi-threaded tests.

    `add_existing_tests` attaches the `_original_test*` methods and
    `multi_threaded` derives a `test*_multi_threaded` method from each of them,
    including the two methods defined below.
    """

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary copies of the source tests.
        if hasattr(cls, "output_folder"):
            cls.output_folder.cleanup()

    def _original_test_create_context(self):
        with Context() as ctx:
            print(ctx._get_live_count())
            print(ctx._get_live_module_count())
            print(ctx._get_live_operation_count())
            print(ctx._get_live_operation_objects())
            print(ctx._get_context_again() is ctx)
            print(ctx._clear_live_operations())

    def _original_test_create_module_with_consts(self):
        py_values = [123, 234, 345]
        with Context():
            module = Module.create(loc=Location.file("foo.txt", 0, 0))

            dtype = IntegerType.get_signless(64)
            for loc_name, py_value in zip(("a", "b", "c"), py_values):
                with InsertionPoint(module.body), Location.name(loc_name):
                    arith.constant(dtype, py_value)


if __name__ == "__main__":
    # Do not run the tests on CPython with GIL
    if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled():
        unittest.main()