Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/agent_chat_cli/components/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from agent_chat_cli.components.messages import (
AgentMessage,
Message,
MessageType,
RoleType,
SystemMessage,
ToolMessage,
UserMessage,
Expand All @@ -20,22 +20,22 @@ def _create_message_widget(
self, message: Message
) -> SystemMessage | UserMessage | AgentMessage | ToolMessage:
match message.type:
case MessageType.SYSTEM:
case RoleType.SYSTEM:
system_widget = SystemMessage()
system_widget.message = message.content
return system_widget

case MessageType.USER:
case RoleType.USER:
user_widget = UserMessage()
user_widget.message = message.content
return user_widget

case MessageType.AGENT:
case RoleType.AGENT:
agent_widget = AgentMessage()
agent_widget.message = message.content
return agent_widget

case MessageType.TOOL:
case RoleType.TOOL:
tool_widget = ToolMessage()

if message.metadata and "tool_name" in message.metadata:
Expand Down
12 changes: 6 additions & 6 deletions src/agent_chat_cli/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from agent_chat_cli.utils import get_tool_info, format_tool_input


class MessageType(Enum):
class RoleType(Enum):
SYSTEM = "system"
USER = "user"
AGENT = "agent"
Expand All @@ -19,26 +19,26 @@ class MessageType(Enum):

@dataclass
class Message:
type: MessageType
type: RoleType
content: str
metadata: dict[str, Any] | None = None

@classmethod
def system(cls, content: str) -> "Message":
return cls(type=MessageType.SYSTEM, content=content)
return cls(type=RoleType.SYSTEM, content=content)

@classmethod
def user(cls, content: str) -> "Message":
return cls(type=MessageType.USER, content=content)
return cls(type=RoleType.USER, content=content)

@classmethod
def agent(cls, content: str) -> "Message":
return cls(type=MessageType.AGENT, content=content)
return cls(type=RoleType.AGENT, content=content)

@classmethod
def tool(cls, tool_name: str, content: str) -> "Message":
return cls(
type=MessageType.TOOL, content=content, metadata={"tool_name": tool_name}
type=RoleType.TOOL, content=content, metadata={"tool_name": tool_name}
)


Expand Down
29 changes: 5 additions & 24 deletions src/agent_chat_cli/core/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from agent_chat_cli.utils.enums import ControlCommand
from agent_chat_cli.components.chat_history import ChatHistory
from agent_chat_cli.components.messages import Message, MessageType
from agent_chat_cli.components.messages import RoleType
from agent_chat_cli.components.tool_permission_prompt import ToolPermissionPrompt
from agent_chat_cli.utils.logger import log_json

Expand All @@ -17,34 +17,15 @@ def __init__(self, app: "AgentChatCLIApp") -> None:
def quit(self) -> None:
self.app.exit()

async def add_message_to_chat_history(
self, type: MessageType, content: str
) -> None:
match type:
case MessageType.USER:
message = Message.user(content)
case MessageType.SYSTEM:
message = Message.system(content)
case MessageType.AGENT:
message = Message.agent(content)
case _:
raise ValueError(f"Unsupported message type: {type}")

chat_history = self.app.query_one(ChatHistory)
chat_history.add_message(message)

async def submit_user_message(self, message: str) -> None:
chat_history = self.app.query_one(ChatHistory)
chat_history.add_message(Message.user(message))
self.app.ui_state.start_thinking()
await self.app.ui_state.scroll_to_bottom()
await self.app.renderer.add_message(RoleType.USER, message)
await self._query(message)

async def post_system_message(self, message: str) -> None:
await self.add_message_to_chat_history(MessageType.SYSTEM, message)
await self.app.renderer.add_message(RoleType.SYSTEM, message)

async def render_message(self, message) -> None:
await self.app.renderer.render_message(message)
async def handle_app_event(self, event) -> None:
await self.app.renderer.handle_app_event(event)

async def interrupt(self) -> None:
permission_prompt = self.app.query_one(ToolPermissionPrompt)
Expand Down
30 changes: 15 additions & 15 deletions src/agent_chat_cli/core/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
get_available_servers,
get_sdk_config,
)
from agent_chat_cli.utils.enums import AgentMessageType, ContentType, ControlCommand
from agent_chat_cli.utils.enums import AppEventType, ContentType, ControlCommand
from agent_chat_cli.utils.logger import log_json
from agent_chat_cli.utils.mcp_server_status import MCPServerStatus

Expand All @@ -33,8 +33,8 @@


@dataclass
class AgentMessage:
type: AgentMessageType
class AppEvent:
type: AppEventType
data: Any


Expand Down Expand Up @@ -93,8 +93,8 @@ async def start(self) -> None:

await self._handle_message(message)

await self.app.actions.render_message(
AgentMessage(type=AgentMessageType.RESULT, data=None)
await self.app.actions.handle_app_event(
AppEvent(type=AppEventType.RESULT, data=None)
)

async def _initialize_client(self, mcp_servers: dict) -> None:
Expand All @@ -115,7 +115,7 @@ async def _handle_message(self, message: Message) -> None:
if isinstance(message, SystemMessage):
log_json(message.data)

if message.subtype == AgentMessageType.INIT.value and message.data.get(
if message.subtype == AppEventType.INIT.value and message.data.get(
"session_id"
):
# When initializing the chat, we store the session_id for later
Expand All @@ -136,9 +136,9 @@ async def _handle_message(self, message: Message) -> None:
text_chunk = delta.get("text", "")

if text_chunk:
await self.app.actions.render_message(
AgentMessage(
type=AgentMessageType.STREAM_EVENT,
await self.app.actions.handle_app_event(
AppEvent(
type=AppEventType.STREAM_EVENT,
data={"text": text_chunk},
)
)
Expand All @@ -164,9 +164,9 @@ async def _handle_message(self, message: Message) -> None:
)

# Finally, post the agent assistant response
await self.app.actions.render_message(
AgentMessage(
type=AgentMessageType.ASSISTANT,
await self.app.actions.handle_app_event(
AppEvent(
type=AppEventType.ASSISTANT,
data={"content": content},
)
)
Expand All @@ -181,9 +181,9 @@ async def _can_use_tool(

# Handle permission request queue sequentially
async with self.permission_lock:
await self.app.actions.render_message(
AgentMessage(
type=AgentMessageType.TOOL_PERMISSION_REQUEST,
await self.app.actions.handle_app_event(
AppEvent(
type=AppEventType.TOOL_PERMISSION_REQUEST,
data={
"tool_name": tool_name,
"tool_input": tool_input,
Expand Down
88 changes: 49 additions & 39 deletions src/agent_chat_cli/core/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from agent_chat_cli.components.chat_history import ChatHistory
from agent_chat_cli.components.messages import (
AgentMessage as AgentMessageWidget,
MessageType,
Message,
RoleType,
ToolMessage,
)
from agent_chat_cli.core.agent_loop import AgentMessage
from agent_chat_cli.utils.enums import AgentMessageType, ContentType
from agent_chat_cli.core.agent_loop import AppEvent
from agent_chat_cli.utils.enums import AppEventType, ContentType
from agent_chat_cli.utils.logger import log_json

if TYPE_CHECKING:
Expand All @@ -32,31 +33,48 @@ def __init__(self, app: "AgentChatCLIApp") -> None:
self.app = app
self._stream = StreamBuffer()

async def render_message(self, message: AgentMessage) -> None:
match message.type:
case AgentMessageType.STREAM_EVENT:
await self._render_stream_event(message)
async def handle_app_event(self, event: AppEvent) -> None:
match event.type:
case AppEventType.STREAM_EVENT:
await self._render_stream_event(event)

case AgentMessageType.ASSISTANT:
await self._render_assistant_message(message)
case AppEventType.ASSISTANT:
await self._render_assistant_message(event)

case AgentMessageType.SYSTEM:
await self._render_system_message(message)
case AppEventType.SYSTEM:
await self._render_system_message(event)

case AgentMessageType.USER:
await self._render_user_message(message)
case AppEventType.USER:
await self._render_user_message(event)

case AgentMessageType.TOOL_PERMISSION_REQUEST:
await self._render_tool_permission_request(message)
case AppEventType.TOOL_PERMISSION_REQUEST:
await self._render_tool_permission_request(event)

case AgentMessageType.RESULT:
case AppEventType.RESULT:
await self._on_complete()

if message.type is not AgentMessageType.RESULT:
if event.type is not AppEventType.RESULT:
await self.app.ui_state.scroll_to_bottom()

async def _render_stream_event(self, message: AgentMessage) -> None:
text_chunk = message.data.get("text", "")
async def add_message(self, type: RoleType, content: str) -> None:
match type:
case RoleType.USER:
message = Message.user(content)
case RoleType.SYSTEM:
message = Message.system(content)
case RoleType.AGENT:
message = Message.agent(content)
case _:
raise ValueError(f"Unsupported message type: {type}")

chat_history = self.app.query_one(ChatHistory)
chat_history.add_message(message)

self.app.ui_state.start_thinking()
await self.app.ui_state.scroll_to_bottom()

async def _render_stream_event(self, event: AppEvent) -> None:
text_chunk = event.data.get("text", "")

if not text_chunk:
return
Expand All @@ -77,8 +95,8 @@ async def _render_stream_event(self, message: AgentMessage) -> None:
markdown = self._stream.widget.query_one(Markdown)
markdown.update(self._stream.text)

async def _render_assistant_message(self, message: AgentMessage) -> None:
content_blocks = message.data.get("content", [])
async def _render_assistant_message(self, event: AppEvent) -> None:
content_blocks = event.data.get("content", [])
chat_history = self.app.query_one(ChatHistory)

for block in content_blocks:
Expand All @@ -97,35 +115,27 @@ async def _render_assistant_message(self, message: AgentMessage) -> None:

await chat_history.mount(tool_msg)

async def _render_system_message(self, message: AgentMessage) -> None:
system_content = (
message.data if isinstance(message.data, str) else str(message.data)
)
async def _render_system_message(self, event: AppEvent) -> None:
system_content = event.data if isinstance(event.data, str) else str(event.data)

await self.app.actions.add_message_to_chat_history(
MessageType.SYSTEM, system_content
)
await self.add_message(RoleType.SYSTEM, system_content)

async def _render_user_message(self, message: AgentMessage) -> None:
user_content = (
message.data if isinstance(message.data, str) else str(message.data)
)
async def _render_user_message(self, event: AppEvent) -> None:
user_content = event.data if isinstance(event.data, str) else str(event.data)

await self.app.actions.add_message_to_chat_history(
MessageType.USER, user_content
)
await self.add_message(RoleType.USER, user_content)

async def _render_tool_permission_request(self, message: AgentMessage) -> None:
async def _render_tool_permission_request(self, event: AppEvent) -> None:
log_json(
{
"event": "showing_permission_prompt",
"tool_name": message.data.get("tool_name", ""),
"tool_name": event.data.get("tool_name", ""),
}
)

self.app.ui_state.show_permission_prompt(
tool_name=message.data.get("tool_name", ""),
tool_input=message.data.get("tool_input", {}),
tool_name=event.data.get("tool_name", ""),
tool_input=event.data.get("tool_input", {}),
)

async def _on_complete(self) -> None:
Expand Down
Loading