diff --git a/comtypes/test/test_logutil.py b/comtypes/test/test_logutil.py index 04f3bd32..4d116daf 100644 --- a/comtypes/test/test_logutil.py +++ b/comtypes/test/test_logutil.py @@ -172,50 +172,62 @@ def open_dbwin_debug_channels() -> Iterator[tuple[int, int, int]]: yield (h_buffer_ready, h_data_ready, p_view) -@contextlib.contextmanager -def capture_debug_strings(ready: threading.Event, *, interval: int) -> Iterator[Queue]: - """Context manager to capture debug strings emitted via `OutputDebugString`. - Spawns a listener thread to monitor the debug channels. - """ - captured = Queue() - finished = threading.Event() - - def _listener( - q: Queue, rdy: threading.Event, fin: threading.Event, pid: int - ) -> None: - # Create/open named events and file mapping for interprocess communication. - # These objects are part of the Windows Debugging API contract. - with open_dbwin_debug_channels() as (h_buffer_ready, h_data_ready, p_view): - rdy.set() # Signal to the main thread that listener is ready. - while not fin.is_set(): # Loop until the main thread signals to finish. - _SetEvent(h_buffer_ready) # Signal readiness to `OutputDebugString`. - # Wait for `OutputDebugString` to signal that data is ready. - if _WaitForSingleObject(h_data_ready, interval) == WAIT_OBJECT_0: - # Debug string buffer format: [4 bytes: PID][N bytes: string]. - # Check if the process ID in the buffer matches the current PID. - if ctypes.cast(p_view, POINTER(DWORD)).contents.value == pid: - # Extract the null-terminated string, skipping the PID, - # and put it into the queue. - q.put(ctypes.string_at(p_view + 4).strip(b"\x00")) +def _listen_on_dbwin_channel( + interval_ms: int, + messages: Queue, + ready: threading.Event, + stop: threading.Event, + pid: int, +) -> None: + # Create/open named events and file mapping for interprocess communication. + # These objects are part of the Windows Debugging API contract. + with open_dbwin_debug_channels() as (h_buffer_ready, h_data_ready, p_view): + ready.set() # Signal to the main thread that listener is ready. + while not stop.is_set(): # Loop until the main thread signals to finish. + _SetEvent(h_buffer_ready) # Signal readiness to `OutputDebugString`. + # Wait for `OutputDebugString` to signal that data is ready. + if _WaitForSingleObject(h_data_ready, interval_ms) == WAIT_OBJECT_0: + # Debug string buffer format: [4 bytes: PID][N bytes: string]. + # Check if the process ID in the buffer matches the current PID. + if ctypes.cast(p_view, POINTER(DWORD)).contents.value == pid: + # Extract the null-terminated string, skipping the PID, + # and put it into the queue. + messages.put(ctypes.string_at(p_view + 4).strip(b"\x00")) + +@contextlib.contextmanager +def _run_dbwin_listener(ready: threading.Event, interval_ms: int) -> Iterator[Queue]: + messages = Queue() + stop = threading.Event() th = threading.Thread( - target=_listener, - args=(captured, ready, finished, _GetCurrentProcessId()), + target=_listen_on_dbwin_channel, + args=(interval_ms, messages, ready, stop, _GetCurrentProcessId()), daemon=True, ) th.start() try: - yield captured + yield messages finally: - finished.set() + stop.set() th.join() +@contextlib.contextmanager +def capture_debug_strings(*, timeout: float, interval: float) -> Iterator[Queue]: + """Context manager to capture debug strings emitted via `OutputDebugString`. + Spawns a listener thread to monitor the debug channels. + + Parameters are floats in seconds. + """ + ready = threading.Event() + with _run_dbwin_listener(ready, int(interval * 1000)) as messages: + ready.wait(timeout=timeout) # Wait for the listener to be ready + yield messages + + class Test_OutputDebugStringW(ut.TestCase): def test(self): - ready = threading.Event() - with capture_debug_strings(ready, interval=100) as cap: - ready.wait(timeout=5) # Wait for the listener to be ready + with capture_debug_strings(timeout=5, interval=0.1) as cap: OutputDebugStringW("hello world") OutputDebugStringW("test message") self.assertEqual(cap.get(), b"hello world") @@ -224,15 +236,20 @@ def test(self): class Test_NTDebugHandler(ut.TestCase): def test_emit(self): - ready = threading.Event() handler = NTDebugHandler() - logger = logging.getLogger("test_ntdebug_handler") + # Direct `Logger()` instantiation for test isolation: bypasses global + # registration and prevents any side effects / cross-test pollution. + # (The official 'Loggers should NEVER be instantiated directly' rule + # targets production code where hierarchy and propagation matter; + # here we want neither.) + # https://docs.python.org/3/library/logging.html#logger-objects + logger = logging.Logger("test_ntdebug_handler") # Clear existing handlers to prevent interference from other tests + logger.propagate = False logger.handlers = [] logger.addHandler(handler) logger.setLevel(logging.INFO) - with capture_debug_strings(ready, interval=100) as cap: - ready.wait(timeout=5) # Wait for the listener to be ready + with capture_debug_strings(timeout=5, interval=0.1) as cap: msg = "This is a test message from NTDebugHandler." logger.info(msg) logger.removeHandler(handler)