From 663e46de0da35a09358c1a16fb23e7c14d70c93e Mon Sep 17 00:00:00 2001 From: shindalsoo Date: Sat, 10 Jan 2026 01:53:33 +0900 Subject: [PATCH] feat: fix multi-turn hitl interrupt detection and state management --- src/agent_server/api/runs.py | 49 ++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/src/agent_server/api/runs.py b/src/agent_server/api/runs.py index 655c014..233d960 100644 --- a/src/agent_server/api/runs.py +++ b/src/agent_server/api/runs.py @@ -288,6 +288,7 @@ async def create_run( - 스트리밍이 필요한 경우 create_and_stream_run을 사용하세요 """ + print(f"[create_run] request for thread_id={thread_id}: {request.model_dump()}") # resume 명령 요구사항을 조기에 검증 if request.command and request.command.get("resume") is not None: # 스레드가 존재하고 중단된 상태인지 확인 @@ -302,12 +303,17 @@ async def create_run( # LangGraph 서비스 가져오기 langgraph_service = get_langgraph_service() - print( - f"create_run: scheduling background task run_id={run_id} thread_id={thread_id} user={user.identity}" - ) - print( - f"[create_run] scheduling background task run_id={run_id} thread_id={thread_id} user={user.identity}" - ) + print(f"[create_run] Scheduling background task run_id={run_id} thread_id={thread_id}") + + # 새로운 입력(input)이 있으면 특정 체크포인트 ID에서 시작하지 않도록 함 (이전 턴의 인터럽트 오염 방지) + actual_checkpoint = request.checkpoint + if request.input and actual_checkpoint and isinstance(actual_checkpoint, dict): + if actual_checkpoint.get("checkpoint_id"): + print(f"[create_run] Clearing checkpoint_id for new turn input to prevent stale resumption.") + actual_checkpoint = actual_checkpoint.copy() + actual_checkpoint["checkpoint_id"] = None + + run_config = create_run_config(run_id, thread_id, user, request.config or {}, actual_checkpoint) # 어시스턴트 존재 여부를 검증하고 graph_id를 가져옵니다. # assistant UUID 대신 graph_id가 제공된 경우, 결정론적으로 매핑하고 @@ -388,7 +394,7 @@ async def create_run( context, stream_modes, None, # 충돌 방지를 위해 session 전달 안 함 - request.checkpoint, + actual_checkpoint, request.command, request.interrupt_before, request.interrupt_after, @@ -442,6 +448,7 @@ async def create_and_stream_run( - on_disconnect=cancel 옵션으로 클라이언트 연결 해제 시 실행 취소 가능 """ + print(f"[create_and_stream_run] request for thread_id={thread_id}: {request.model_dump()}") # resume 명령 요구사항을 조기에 검증 if request.command and request.command.get("resume") is not None: # 스레드가 존재하고 중단된 상태인지 확인 @@ -454,11 +461,17 @@ async def create_and_stream_run( run_id = str(uuid4()) + # 새로운 입력(input)이 있으면 특정 체크포인트 ID에서 시작하지 않도록 함 (이전 턴의 인터럽트 오염 방지) + actual_checkpoint = request.checkpoint + if request.input and actual_checkpoint and isinstance(actual_checkpoint, dict): + if actual_checkpoint.get("checkpoint_id"): + print(f"[create_and_stream_run] Clearing checkpoint_id for new turn input to prevent stale resumption.") + actual_checkpoint = actual_checkpoint.copy() + actual_checkpoint["checkpoint_id"] = None + # LangGraph 서비스 가져오기 langgraph_service = get_langgraph_service() - print( - f"[create_and_stream_run] scheduling background task run_id={run_id} thread_id={thread_id} user={user.identity}" - ) + print(f"[create_and_stream_run] Scheduling background task run_id={run_id} thread_id={thread_id}") # 어시스턴트 존재 여부를 검증하고 graph_id를 가져옵니다. # graph_id를 전달하면 결정론적 어시스턴트 ID로 매핑합니다. @@ -538,7 +551,7 @@ async def create_and_stream_run( context, stream_modes, None, # 충돌 방지를 위해 session 전달 안 함 - request.checkpoint, + actual_checkpoint, request.command, request.interrupt_before, request.interrupt_after, @@ -1340,6 +1353,7 @@ async def execute_run_async( if isinstance(event_data, dict) and "__interrupt__" in event_data: has_interrupt = True + print(f"[execute_run_async] Detected interrupt via event: {event_data}") # 최종 출력 추적 if isinstance(raw_event, tuple): @@ -1349,6 +1363,19 @@ async def execute_run_async( # 튜플이 아닌 이벤트는 values 모드 final_output = raw_event + # 스트림 완료 후 스레드 상태를 확인하여 interrupt 여부 판단 + # LangGraph의 interrupt()는 이벤트에 __interrupt__를 추가하지만, + # 더 확실한 방법은 스레드 상태의 'next' 필드를 확인하는 것 + try: + thread_state = await graph.aget_state(run_config) + # 'next' 필드가 있으면 그래프가 중단되어 다음 노드를 기다리는 상태 + if thread_state and hasattr(thread_state, 'next') and thread_state.next: + has_interrupt = True + print(f"[execute_run_async] Detected interrupt via thread state: next={thread_state.next}") + except Exception as e: + # thread state 확인 실패 시 기존 이벤트 기반 감지 결과 사용 + print(f"[execute_run_async] Failed to check thread state: {e}") + if has_interrupt: await update_run_status(run_id, "interrupted", output=final_output or {}, session=session) if not session: