Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 81 additions & 35 deletions kis_agent/websocket/ws_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ def __init__(
self._subscription_results: Dict[str, bool] = {} # True=성공, False=실패
self._subscription_errors: Dict[str, str] = {} # 실패 시 에러 메시지

# Fire-and-forget 백그라운드 태스크 참조 보관.
# asyncio.create_task의 반환값을 어디에도 저장하지 않으면 GC에 의해
# 태스크가 사라질 수 있다 (Python docs 권장 패턴). disconnect 시
# 미완료 태스크를 일괄 cancel하기 위해서도 필요.
self._background_tasks: Set[asyncio.Task] = set()

# 통계
self.stats = {
"messages_received": 0,
Expand All @@ -154,6 +160,18 @@ def _ws_closed(self) -> bool:
# websockets v14+: close_code가 설정되면 연결이 닫힌 것
return self.ws.close_code is not None

def _track_task(self, task: asyncio.Task) -> asyncio.Task:
"""백그라운드 태스크 참조 보관 + 완료 시 자동 제거.

``asyncio.create_task``의 반환값을 어디에도 저장하지 않으면 GC에 의해
태스크가 weakly-referenced 상태로 사라질 수 있다 (CPython 3.10+에서
문서화된 동작). 이 헬퍼는 set에 강한 참조를 추가하고, 태스크 완료 시
callback으로 set에서 제거하여 메모리 누수도 막는다.
"""
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task

def update_approval_key(self, new_approval_key: str) -> None:
"""
approval_key 갱신 (토큰 재발급 시 사용)
Expand Down Expand Up @@ -227,7 +245,7 @@ def subscribe(

# 연결되어 있으면 비동기 태스크 생성 (결과 추적)
if self.connected and self.ws and not self._ws_closed():
task = asyncio.create_task(self._send_subscription(subscription))
task = self._track_task(asyncio.create_task(self._send_subscription(subscription)))
# 태스크 완료 시 실패 로깅을 위한 콜백 추가
task.add_done_callback(lambda t: self._on_subscription_task_done(t, sub_id))

Expand Down Expand Up @@ -302,7 +320,7 @@ def unsubscribe(self, sub_id: str):

# 구독 해제 메시지 전송 (연결 상태 상세 검증)
if self.connected and self.ws and not self._ws_closed():
asyncio.create_task(self._send_unsubscription(subscription))
self._track_task(asyncio.create_task(self._send_unsubscription(subscription)))

# 구독 정보 삭제
del self.subscriptions[sub_id]
Expand Down Expand Up @@ -430,8 +448,10 @@ async def _send_subscription(
logger.error(f"구독 요청 오류: {sub_id} - {e}")

finally:
# 정리
# 정리 (재시도 사이 stale state 사용 방지)
self._pending_subscriptions.pop(sub_id, None)
self._subscription_results.pop(sub_id, None)
self._subscription_errors.pop(sub_id, None)

# 재시도 전 대기 (지수 백오프)
if attempt < max_retries - 1:
Expand Down Expand Up @@ -471,21 +491,23 @@ async def _send_unsubscription(self, subscription: Subscription):

async def _subscribe_all(self) -> dict:
"""
모든 구독 요청 전송 (순차 처리 + 딜레이로 안정성 개선)
모든 구독 요청을 순차 전송한다.

KIS 서버에서 빠른 구독 요청을 처리하지 못하는 문제를 해결하기 위해
순차 처리와 딜레이를 적용합니다.
``_send_subscription``은 각 구독마다 응답 이벤트를 기다리므로 자연스럽게
직렬화되어 KIS 서버의 rate limit을 준수한다. 응답 수신 직후 다음 요청을
보내므로 추가 sleep이 불필요하다 (이전엔 0.5초/구독 sleep으로 50종목
구독 시 ~25초가 추가로 소비됐다).

Returns:
dict: 구독 결과 {"success": [...], "failed": [...]}
"""
results = {"success": [], "failed": []}
total = len(self.subscriptions)
subs = list(self.subscriptions.values())
total = len(subs)

logger.info(f"구독 시작: 총 {total}개 종목 (순차 처리)")

# 순차 처리로 구독 요청 (연결 안정성 향상)
for idx, subscription in enumerate(self.subscriptions.values()):
for idx, subscription in enumerate(subs):
sub_id = f"{subscription.sub_type.value}_{subscription.key}"

# 연결 상태 확인
Expand All @@ -494,8 +516,7 @@ async def _subscribe_all(self) -> dict:
f"구독 중단 - 연결 끊김 (성공: {len(results['success'])}, 남은: {total - idx})"
)
# 남은 구독들을 실패로 처리
remaining_subs = list(self.subscriptions.values())[idx:]
for remaining in remaining_subs:
for remaining in subs[idx:]:
remaining_id = f"{remaining.sub_type.value}_{remaining.key}"
results["failed"].append(remaining_id)
break
Expand All @@ -513,10 +534,6 @@ async def _subscribe_all(self) -> dict:
f"구독 진행: {idx + 1}/{total} (성공: {len(results['success'])})"
)

# 구독 사이 딜레이 (0.5초) - KIS 서버 rate limit 준수
if idx < total - 1:
await asyncio.sleep(0.5)

except Exception as e:
logger.error(f"구독 요청 중 예외 발생: {sub_id} - {e}")
results["failed"].append(sub_id)
Expand All @@ -530,16 +547,21 @@ async def _subscribe_all(self) -> dict:

return results

def _parse_message(self, data: str) -> tuple:
"""
메시지 파싱
def _parse_message(self, data: str, json_data: Optional[dict] = None) -> tuple:
"""메시지 파싱.

Args:
data: 원본 raw 메시지.
json_data: 호출자가 이미 ``json.loads(data)``를 수행했다면 결과를
전달해 이중 파싱을 피한다.

Returns:
(tr_id, tr_key, parsed_data)
"""
if data.startswith("{"):
# JSON 메시지
json_data = json.loads(data)
# JSON 메시지 (이중 파싱 방지)
if json_data is None:
json_data = json.loads(data)
header = json_data.get("header", {})
body = json_data.get("body", {})

Expand Down Expand Up @@ -649,16 +671,19 @@ async def _handle_message(self, data: str):
if "PINGPONG" in data:
return

# JSON 메시지인 경우 구독 응답 먼저 확인
# JSON 메시지인 경우 한 번만 파싱하고 구독 응답 먼저 확인.
# 일반 데이터 처리를 위한 _parse_message에도 같은 json_data를 넘겨
# 이중 파싱을 막는다.
preparsed_json = None
if data.startswith("{"):
try:
json_data = json.loads(data)
if self._handle_subscription_response(json_data):
preparsed_json = json.loads(data)
if self._handle_subscription_response(preparsed_json):
return # 구독 응답 메시지는 여기서 처리 완료
except json.JSONDecodeError:
json_data = None # JSON 아닌 경우 일반 메시지로 계속 처리
preparsed_json = None

tr_id, tr_key, parsed_data = self._parse_message(data)
tr_id, tr_key, parsed_data = self._parse_message(data, preparsed_json)

if not tr_id:
return
Expand Down Expand Up @@ -712,11 +737,24 @@ async def _handle_message(self, data: str):
self.stats["errors"] += 1

async def _call_handler(self, handler: Callable, data: Any, metadata: Dict):
"""핸들러 호출"""
if asyncio.iscoroutinefunction(handler):
await handler(data, metadata)
else:
handler(data, metadata)
"""핸들러 호출.

- 코루틴 핸들러는 직접 await.
- 동기 핸들러는 ``asyncio.to_thread()``로 격리 실행. 이렇게 하면
무거운 CPU/IO 작업이 있는 동기 핸들러도 수신 루프(_receive_loop)를
블록하지 않아 ping/pong 응답 지연으로 인한 강제 재연결이 발생하지 않는다.
- 핸들러 자체에서 발생한 예외는 잡아 로깅만 — 한 핸들러의 오류가
다른 메시지 처리를 막지 않도록.
"""
try:
if asyncio.iscoroutinefunction(handler):
await handler(data, metadata)
else:
# to_thread로 격리해 수신 루프 블록 방지
await asyncio.to_thread(handler, data, metadata)
except Exception as e:
logger.error(f"핸들러 실행 오류 ({getattr(handler, '__name__', handler)}): {e}")
self.stats["errors"] += 1

async def _receive_loop(self, websocket) -> str:
"""
Expand Down Expand Up @@ -965,18 +1003,26 @@ async def connect(self):
await asyncio.sleep(backoff)

async def disconnect(self):
"""
웹소켓 연결을 종료합니다.
"""웹소켓 연결을 종료한다.

auto_reconnect를 False로 설정하고 현재 연결을 닫습니다.
``auto_reconnect``를 False로 설정하고 현재 연결을 닫는다. 동시에
진행 중인 fire-and-forget 백그라운드 태스크(_send_subscription
/ _send_unsubscription)도 cancel하여 누수를 방지한다.
"""
"""웹소켓 연결 종료"""
self.auto_reconnect = False
if self.ws:
await self.ws.close()
self.ws = None
self.connected = False
logger.info("웹소켓 연결 종료")

# 백그라운드 태스크 정리
pending = [t for t in self._background_tasks if not t.done()]
for task in pending:
task.cancel()
if pending:
await asyncio.gather(*pending, return_exceptions=True)

logger.info(f"웹소켓 연결 종료 (백그라운드 태스크 {len(pending)}개 정리)")

def get_stats(self) -> Dict[str, Any]:
"""
Expand Down
Loading
Loading