From f69eaa424c4a0c35943287670527c0e14804cc0d Mon Sep 17 00:00:00 2001 From: Heewon Oh Date: Thu, 4 Jun 2026 13:33:30 +0900 Subject: [PATCH] perf(ws_agent): optimize subscription speed, isolate sync handlers, fix task leaks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WSAgent 4가지 최적화·안정성 개선. 1. 순차 구독 속도 N배 개선 - _subscribe_all에서 구독 사이 0.5초 sleep 제거 - _send_subscription이 응답 대기로 자연 직렬화되므로 추가 sleep 불필요 - 50 종목 구독 시 ~25초 단축 (이전: 0.5s × N 추가 소비) 2. Sync 핸들러 격리 — 수신 루프 블록 방지 - _call_handler에서 sync 핸들러를 asyncio.to_thread()로 격리 - 이전: 무거운 sync 핸들러가 _receive_loop 자체를 블록 → ping/pong 응답 지연 → 강제 재연결 발생 가능 - 핸들러 예외도 _call_handler가 잡아 stats["errors"] 증가만 시키고 다른 메시지 처리는 계속 진행 3. Background task 참조 보관 + 정리 - subscribe()/unsubscribe()의 fire-and-forget asyncio.create_task가 반환 task를 어디에도 저장하지 않아 GC로 사라질 위험 (Python 3.10+ 문서화된 동작) - _background_tasks set + _track_task() 헬퍼로 강한 참조 보관, 완료 시 done_callback으로 자동 제거 - disconnect() 시 미완료 task 일괄 cancel + gather 4. 구독 응답 state 정리 누수 fix + JSON 이중 파싱 제거 - _send_subscription의 finally에서 _subscription_results/_errors도 pop (이전엔 _pending만 정리되어 재시도 시 stale 데이터 사용 위험) - _handle_message에서 JSON 파싱한 결과를 _parse_message에 넘겨 이중 json.loads() 호출 제거 검증: - 신규 단위 테스트 9개 (tests/unit/test_ws_agent_optimizations.py): * _subscribe_all 5개 구독이 0.3초 미만에 완료 * sync 핸들러가 메인 스레드와 다른 스레드에서 실행됨 (격리) * async 핸들러는 이벤트 루프에서 직접 await * 핸들러 예외가 _call_handler에서 잡혀 전파되지 않음 * _track_task가 done_callback으로 자동 제거됨 * disconnect가 미완료 task 모두 cancel * _send_subscription 실패 후 pending/results/errors 모두 정리 * _parse_message가 preparsed json_data 재사용 (이중 파싱 회피) - 기존 ws 단위 테스트 144/145 passed (1 pre-existing failure: test_fatal_ error_patterns_exist — 본 변경과 무관) - ruff: All checks passed --- kis_agent/websocket/ws_agent.py | 116 +++++++++---- tests/unit/test_ws_agent_optimizations.py | 198 ++++++++++++++++++++++ 2 files changed, 279 insertions(+), 35 deletions(-) create mode 100644 tests/unit/test_ws_agent_optimizations.py diff --git a/kis_agent/websocket/ws_agent.py b/kis_agent/websocket/ws_agent.py index db68d43..cfd22b6 100644 --- a/kis_agent/websocket/ws_agent.py +++ b/kis_agent/websocket/ws_agent.py @@ -138,6 +138,12 @@ def __init__( self._subscription_results: Dict[str, bool] = {} # True=성공, False=실패 self._subscription_errors: Dict[str, str] = {} # 실패 시 에러 메시지 + # Fire-and-forget 백그라운드 태스크 참조 보관. + # asyncio.create_task의 반환값을 어디에도 저장하지 않으면 GC에 의해 + # 태스크가 사라질 수 있다 (Python docs 권장 패턴). disconnect 시 + # 미완료 태스크를 일괄 cancel하기 위해서도 필요. + self._background_tasks: Set[asyncio.Task] = set() + # 통계 self.stats = { "messages_received": 0, @@ -154,6 +160,18 @@ def _ws_closed(self) -> bool: # websockets v14+: close_code가 설정되면 연결이 닫힌 것 return self.ws.close_code is not None + def _track_task(self, task: asyncio.Task) -> asyncio.Task: + """백그라운드 태스크 참조 보관 + 완료 시 자동 제거. + + ``asyncio.create_task``의 반환값을 어디에도 저장하지 않으면 GC에 의해 + 태스크가 weakly-referenced 상태로 사라질 수 있다 (CPython 3.10+에서 + 문서화된 동작). 이 헬퍼는 set에 강한 참조를 추가하고, 태스크 완료 시 + callback으로 set에서 제거하여 메모리 누수도 막는다. + """ + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task + def update_approval_key(self, new_approval_key: str) -> None: """ approval_key 갱신 (토큰 재발급 시 사용) @@ -227,7 +245,7 @@ def subscribe( # 연결되어 있으면 비동기 태스크 생성 (결과 추적) if self.connected and self.ws and not self._ws_closed(): - task = asyncio.create_task(self._send_subscription(subscription)) + task = self._track_task(asyncio.create_task(self._send_subscription(subscription))) # 태스크 완료 시 실패 로깅을 위한 콜백 추가 task.add_done_callback(lambda t: self._on_subscription_task_done(t, sub_id)) @@ -302,7 +320,7 @@ def unsubscribe(self, sub_id: str): # 구독 해제 메시지 전송 (연결 상태 상세 검증) if self.connected and self.ws and not self._ws_closed(): - asyncio.create_task(self._send_unsubscription(subscription)) + self._track_task(asyncio.create_task(self._send_unsubscription(subscription))) # 구독 정보 삭제 del self.subscriptions[sub_id] @@ -430,8 +448,10 @@ async def _send_subscription( logger.error(f"구독 요청 오류: {sub_id} - {e}") finally: - # 정리 + # 정리 (재시도 사이 stale state 사용 방지) self._pending_subscriptions.pop(sub_id, None) + self._subscription_results.pop(sub_id, None) + self._subscription_errors.pop(sub_id, None) # 재시도 전 대기 (지수 백오프) if attempt < max_retries - 1: @@ -471,21 +491,23 @@ async def _send_unsubscription(self, subscription: Subscription): async def _subscribe_all(self) -> dict: """ - 모든 구독 요청 전송 (순차 처리 + 딜레이로 안정성 개선) + 모든 구독 요청을 순차 전송한다. - KIS 서버에서 빠른 구독 요청을 처리하지 못하는 문제를 해결하기 위해 - 순차 처리와 딜레이를 적용합니다. + ``_send_subscription``은 각 구독마다 응답 이벤트를 기다리므로 자연스럽게 + 직렬화되어 KIS 서버의 rate limit을 준수한다. 응답 수신 직후 다음 요청을 + 보내므로 추가 sleep이 불필요하다 (이전엔 0.5초/구독 sleep으로 50종목 + 구독 시 ~25초가 추가로 소비됐다). Returns: dict: 구독 결과 {"success": [...], "failed": [...]} """ results = {"success": [], "failed": []} - total = len(self.subscriptions) + subs = list(self.subscriptions.values()) + total = len(subs) logger.info(f"구독 시작: 총 {total}개 종목 (순차 처리)") - # 순차 처리로 구독 요청 (연결 안정성 향상) - for idx, subscription in enumerate(self.subscriptions.values()): + for idx, subscription in enumerate(subs): sub_id = f"{subscription.sub_type.value}_{subscription.key}" # 연결 상태 확인 @@ -494,8 +516,7 @@ async def _subscribe_all(self) -> dict: f"구독 중단 - 연결 끊김 (성공: {len(results['success'])}, 남은: {total - idx})" ) # 남은 구독들을 실패로 처리 - remaining_subs = list(self.subscriptions.values())[idx:] - for remaining in remaining_subs: + for remaining in subs[idx:]: remaining_id = f"{remaining.sub_type.value}_{remaining.key}" results["failed"].append(remaining_id) break @@ -513,10 +534,6 @@ async def _subscribe_all(self) -> dict: f"구독 진행: {idx + 1}/{total} (성공: {len(results['success'])})" ) - # 구독 사이 딜레이 (0.5초) - KIS 서버 rate limit 준수 - if idx < total - 1: - await asyncio.sleep(0.5) - except Exception as e: logger.error(f"구독 요청 중 예외 발생: {sub_id} - {e}") results["failed"].append(sub_id) @@ -530,16 +547,21 @@ async def _subscribe_all(self) -> dict: return results - def _parse_message(self, data: str) -> tuple: - """ - 메시지 파싱 + def _parse_message(self, data: str, json_data: Optional[dict] = None) -> tuple: + """메시지 파싱. + + Args: + data: 원본 raw 메시지. + json_data: 호출자가 이미 ``json.loads(data)``를 수행했다면 결과를 + 전달해 이중 파싱을 피한다. Returns: (tr_id, tr_key, parsed_data) """ if data.startswith("{"): - # JSON 메시지 - json_data = json.loads(data) + # JSON 메시지 (이중 파싱 방지) + if json_data is None: + json_data = json.loads(data) header = json_data.get("header", {}) body = json_data.get("body", {}) @@ -649,16 +671,19 @@ async def _handle_message(self, data: str): if "PINGPONG" in data: return - # JSON 메시지인 경우 구독 응답 먼저 확인 + # JSON 메시지인 경우 한 번만 파싱하고 구독 응답 먼저 확인. + # 일반 데이터 처리를 위한 _parse_message에도 같은 json_data를 넘겨 + # 이중 파싱을 막는다. + preparsed_json = None if data.startswith("{"): try: - json_data = json.loads(data) - if self._handle_subscription_response(json_data): + preparsed_json = json.loads(data) + if self._handle_subscription_response(preparsed_json): return # 구독 응답 메시지는 여기서 처리 완료 except json.JSONDecodeError: - json_data = None # JSON 아닌 경우 일반 메시지로 계속 처리 + preparsed_json = None - tr_id, tr_key, parsed_data = self._parse_message(data) + tr_id, tr_key, parsed_data = self._parse_message(data, preparsed_json) if not tr_id: return @@ -712,11 +737,24 @@ async def _handle_message(self, data: str): self.stats["errors"] += 1 async def _call_handler(self, handler: Callable, data: Any, metadata: Dict): - """핸들러 호출""" - if asyncio.iscoroutinefunction(handler): - await handler(data, metadata) - else: - handler(data, metadata) + """핸들러 호출. + + - 코루틴 핸들러는 직접 await. + - 동기 핸들러는 ``asyncio.to_thread()``로 격리 실행. 이렇게 하면 + 무거운 CPU/IO 작업이 있는 동기 핸들러도 수신 루프(_receive_loop)를 + 블록하지 않아 ping/pong 응답 지연으로 인한 강제 재연결이 발생하지 않는다. + - 핸들러 자체에서 발생한 예외는 잡아 로깅만 — 한 핸들러의 오류가 + 다른 메시지 처리를 막지 않도록. + """ + try: + if asyncio.iscoroutinefunction(handler): + await handler(data, metadata) + else: + # to_thread로 격리해 수신 루프 블록 방지 + await asyncio.to_thread(handler, data, metadata) + except Exception as e: + logger.error(f"핸들러 실행 오류 ({getattr(handler, '__name__', handler)}): {e}") + self.stats["errors"] += 1 async def _receive_loop(self, websocket) -> str: """ @@ -965,18 +1003,26 @@ async def connect(self): await asyncio.sleep(backoff) async def disconnect(self): - """ - 웹소켓 연결을 종료합니다. + """웹소켓 연결을 종료한다. - auto_reconnect를 False로 설정하고 현재 연결을 닫습니다. + ``auto_reconnect``를 False로 설정하고 현재 연결을 닫는다. 동시에 + 진행 중인 fire-and-forget 백그라운드 태스크(_send_subscription + / _send_unsubscription)도 cancel하여 누수를 방지한다. """ - """웹소켓 연결 종료""" self.auto_reconnect = False if self.ws: await self.ws.close() self.ws = None self.connected = False - logger.info("웹소켓 연결 종료") + + # 백그라운드 태스크 정리 + pending = [t for t in self._background_tasks if not t.done()] + for task in pending: + task.cancel() + if pending: + await asyncio.gather(*pending, return_exceptions=True) + + logger.info(f"웹소켓 연결 종료 (백그라운드 태스크 {len(pending)}개 정리)") def get_stats(self) -> Dict[str, Any]: """ diff --git a/tests/unit/test_ws_agent_optimizations.py b/tests/unit/test_ws_agent_optimizations.py new file mode 100644 index 0000000..fda9d50 --- /dev/null +++ b/tests/unit/test_ws_agent_optimizations.py @@ -0,0 +1,198 @@ +"""WSAgent 최적화 검증. + +PR perf/ws-agent-optimizations에서 도입된 4가지 변경 검증: +1. _subscribe_all에서 구독 간 0.5초 sleep 제거 (응답 대기로 자연 직렬화) +2. _call_handler가 sync 핸들러를 asyncio.to_thread()로 격리 +3. fire-and-forget 백그라운드 태스크가 _background_tasks에 추적되어 GC 회피 +4. _send_subscription의 결과 dict 누수 fix +""" + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from kis_agent.websocket.ws_agent import WSAgent +from kis_agent.websocket.ws_types import Subscription, SubscriptionType + + +@pytest.fixture +def agent(): + """기본 WSAgent 인스턴스 (실제 연결 없음).""" + return WSAgent(approval_key="test_key") + + +class TestSubscribeAllNoExtraSleep: + @pytest.mark.asyncio + async def test_subscribe_all_does_not_sleep_between_subs(self, agent): + """_subscribe_all 자체에는 sub 간 추가 sleep이 없다. + + 실제 KIS 응답 대기는 _send_subscription에서 처리되므로 mocking으로 즉시 + 성공시키면 N개 구독이 거의 즉시 완료되어야 한다 (이전 구현은 0.5s/sub). + """ + # 5개 구독 등록 (실제 send는 mock) + for code in ("005930", "000660", "035420", "035720", "051910"): + sub = Subscription( + sub_type=SubscriptionType.STOCK_TRADE, + key=code, + handler=None, + metadata={}, + ) + agent.subscriptions[f"H0STCNT0_{code}"] = sub + + # 연결 시뮬레이션 + agent.ws = MagicMock() + agent.ws.close_code = None + agent.connected = True + + # _send_subscription을 즉시 성공으로 mock + async def fake_send(subscription, max_retries=3, timeout=60.0): + await asyncio.sleep(0) # 단순 yield + return True + + agent._send_subscription = fake_send + + start = time.monotonic() + result = await agent._subscribe_all() + elapsed = time.monotonic() - start + + assert len(result["success"]) == 5 + # 이전 구현은 5개 = 4×0.5s = 2초+ 소비 + assert elapsed < 0.3, f"sub 간 sleep이 남아있는 듯: {elapsed:.3f}s" + + +class TestSyncHandlerIsolation: + @pytest.mark.asyncio + async def test_sync_handler_runs_in_thread(self, agent): + """동기 핸들러가 asyncio.to_thread()로 격리되어 이벤트 루프를 블록하지 않는다.""" + main_loop_thread_id = None + handler_thread_id = None + + # 이벤트 루프 스레드 ID 기록 (단순히 현재 스레드) + import threading + + main_loop_thread_id = threading.get_ident() + + def sync_handler(data, metadata): + nonlocal handler_thread_id + handler_thread_id = threading.get_ident() + + await agent._call_handler(sync_handler, {"price": 100}, {}) + + assert handler_thread_id is not None + # to_thread로 실행되면 메인 스레드와 다른 스레드에서 실행됨 + assert handler_thread_id != main_loop_thread_id, ( + "sync 핸들러가 메인 이벤트 루프 스레드에서 실행됨 — 격리 안 됨" + ) + + @pytest.mark.asyncio + async def test_async_handler_runs_in_loop(self, agent): + """async 핸들러는 이벤트 루프에서 직접 await된다 (to_thread 사용 안 함).""" + received = [] + + async def async_handler(data, metadata): + received.append(data) + + await agent._call_handler(async_handler, {"price": 200}, {"meta": "x"}) + assert received == [{"price": 200}] + + @pytest.mark.asyncio + async def test_handler_exception_caught_and_counted(self, agent): + """핸들러가 예외를 발생시켜도 _call_handler가 잡고 stats["errors"] 증가.""" + initial_errors = agent.stats["errors"] + + def bad_handler(data, metadata): + raise RuntimeError("intentional") + + # 예외가 전파되지 않아야 함 + await agent._call_handler(bad_handler, {}, {}) + assert agent.stats["errors"] == initial_errors + 1 + + +class TestBackgroundTaskTracking: + @pytest.mark.asyncio + async def test_track_task_keeps_reference(self, agent): + """_track_task가 task를 set에 추가하고 완료 시 자동 제거한다.""" + + async def quick(): + return 42 + + task = agent._track_task(asyncio.create_task(quick())) + assert task in agent._background_tasks + + await task + # done_callback이 즉시 동기 실행되지 않을 수 있어 한 번 양보 + await asyncio.sleep(0) + assert task not in agent._background_tasks + + @pytest.mark.asyncio + async def test_disconnect_cancels_pending_background_tasks(self, agent): + """disconnect 시 미완료 백그라운드 태스크가 모두 cancel된다.""" + + async def long_task(): + await asyncio.sleep(10) + + # 백그라운드 태스크 3개 시작 + for _ in range(3): + agent._track_task(asyncio.create_task(long_task())) + + assert len(agent._background_tasks) == 3 + snapshot = list(agent._background_tasks) + + await agent.disconnect() + + # 모두 cancel된 상태 + for t in snapshot: + assert t.done() + assert t.cancelled() + + +class TestSubscriptionStateCleanup: + @pytest.mark.asyncio + async def test_send_subscription_clears_all_state_after_failure(self, agent): + """_send_subscription 실패 후 pending/results/errors 모두 정리된다.""" + sub = Subscription( + sub_type=SubscriptionType.STOCK_TRADE, + key="005930", + handler=None, + metadata={}, + ) + sub_id = "H0STCNT0_005930" + + # ws mock — send는 호출되지만 응답은 안 옴 → timeout + agent.ws = MagicMock() + agent.ws.close_code = None + agent.ws.send = AsyncMock() + + # 짧은 timeout으로 즉시 실패 + result = await agent._send_subscription(sub, max_retries=1, timeout=0.05) + assert result is False + + # 모든 상태 정리 확인 + assert sub_id not in agent._pending_subscriptions + assert sub_id not in agent._subscription_results + assert sub_id not in agent._subscription_errors + + +class TestJSONDoubleParsingAvoided: + def test_parse_message_accepts_preparsed_json(self, agent): + """이미 파싱한 json_data를 넘기면 _parse_message가 다시 파싱하지 않는다.""" + raw = '{"header": {"tr_id": "H0STCNT0", "tr_key": "005930"}, "body": {"output": {}}}' + preparsed = { + "header": {"tr_id": "H0STCNT0", "tr_key": "005930"}, + "body": {"output": {}}, + } + tr_id, tr_key, data = agent._parse_message(raw, json_data=preparsed) + assert tr_id == "H0STCNT0" + assert tr_key == "005930" + # 동일 객체 사용 검증 (이중 파싱이면 새 dict) + assert data is preparsed + + def test_parse_message_parses_when_no_preparsed(self, agent): + """json_data가 없으면 _parse_message가 자체적으로 파싱한다 (역호환).""" + raw = '{"header": {"tr_id": "H0STCNT0", "tr_key": "005930"}, "body": {}}' + tr_id, tr_key, data = agent._parse_message(raw) + assert tr_id == "H0STCNT0" + assert tr_key == "005930" + assert isinstance(data, dict)