diff --git a/src/app/api/chat/route.ts b/src/app/api/chat/route.ts index c74dba3..fc28dac 100644 --- a/src/app/api/chat/route.ts +++ b/src/app/api/chat/route.ts @@ -3,7 +3,7 @@ import { requireAuth } from "@/lib/require-auth"; import { isValidVoiceUploadToken } from "@/lib/voice-upload-tokens"; import { getGatewayClient, holdClient, releaseClient } from "@/lib/gateway-chat-pool"; import { db, withRetry } from "@/db"; -import { agents, channelMembers, chatMessages, chatRuns, chatSessions, chatThreads } from "@/db/schema"; +import { agents, channelMembers, channels, chatMessages, chatRuns, chatSessions, chatThreads } from "@/db/schema"; import { eq, desc, and, isNull, sql } from "drizzle-orm"; import { publishChatEvent, publishChatProgressEvent } from "@/lib/chat-pubsub"; import { isAssistantDeliveryPlaceholder, selectRecoveredAssistantText } from "@/lib/chat-recovery"; @@ -27,7 +27,8 @@ const ACTIVE_HISTORY_POLL_INTERVAL_MS = 1_500; const ACTIVE_HISTORY_POLL_MAX_ATTEMPTS = 80; const AGENT_MODE_THINKING_LEVEL = "low"; const CHANNEL_AGENT_SPEAKING_ROLES = new Set(["owner", "admin", "member", "contributor"]); -const CHANNEL_AGENT_SPEAKING_MODES = new Set(["mention_only", "proactive", "on_call"]); +const CHANNEL_AGENT_MENTION_MODES = new Set(["mention_only", "proactive", "on_call"]); +const CHANNEL_AGENT_ACTIVE_MODES = new Set(["proactive", "on_call"]); type ChatProgressEventName = | "run_started" @@ -812,6 +813,9 @@ async function resolveChannelAgentInvocationViolation(params: { channelId: string; companyId: string | null; agentCallsign: string; + agentMode: boolean; + invocationMode: "active" | "mention"; + messageText: string; }) { if (!db) return null; const callsign = params.agentCallsign.trim().toLowerCase(); @@ -821,6 +825,7 @@ async function resolveChannelAgentInvocationViolation(params: { db!.select({ agent: agents, member: channelMembers, + channel: channels, }) .from(agents) .innerJoin( @@ -830,6 +835,7 @@ async function resolveChannelAgentInvocationViolation(params: { eq(channelMembers.channelId, params.channelId), ), ) + .innerJoin(channels, eq(channels.id, channelMembers.channelId)) .where(sql`lower(${agents.callsign}) = ${callsign}`) .limit(1) ); @@ -838,9 +844,19 @@ async function resolveChannelAgentInvocationViolation(params: { if (!CHANNEL_AGENT_SPEAKING_ROLES.has(row.member.role)) { return "Agent cannot post in this channel."; } - if (!CHANNEL_AGENT_SPEAKING_MODES.has(row.member.agentParticipationMode ?? "mention_only")) { + const participationMode = row.member.agentParticipationMode ?? "mention_only"; + const isDirectAgentDm = row.channel.type === "dm"; + const isActiveParticipant = isDirectAgentDm || CHANNEL_AGENT_ACTIVE_MODES.has(participationMode); + const isExplicitMention = !params.agentMode && + params.invocationMode === "mention" && + messageMentionsAgent(params.messageText, callsign, row.agent.name); + + if (isExplicitMention && !CHANNEL_AGENT_MENTION_MODES.has(participationMode) && !isDirectAgentDm) { return "Agent is not configured to respond in this channel."; } + if (!isExplicitMention && !isActiveParticipant) { + return "Agent is not an active participant in this channel."; + } if (params.companyId) { const ownerCompanyId = row.agent.ownerCompanyId ?? row.agent.companyId ?? null; if (row.agent.ownerType !== "company" || ownerCompanyId !== params.companyId) { @@ -850,6 +866,23 @@ async function resolveChannelAgentInvocationViolation(params: { return null; } +function escapeRegExp(value: string) { + return value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); +} + +function messageMentionsAgent(message: string, callsign: string, name?: string | null) { + const aliases = [callsign, name].filter((value): value is string => Boolean(value?.trim())); + return aliases.some((alias) => { + const escaped = escapeRegExp(alias.trim().toLowerCase()); + return [ + new RegExp(`^@${escaped}\\b`, "i"), + new RegExp(`^${escaped}[,:\\s]`, "i"), + new RegExp(`^hey\\s+${escaped}\\b`, "i"), + new RegExp(`\\b@${escaped}\\b`, "i"), + ].some((pattern) => pattern.test(message)); + }); +} + export async function POST(request: NextRequest) { const bearerToken = request.headers.get("authorization")?.replace(/^Bearer\s+/i, ""); const authError = isValidVoiceUploadToken(bearerToken) ? null : await requireAuth(request); @@ -867,6 +900,7 @@ export async function POST(request: NextRequest) { channelId: bodyChannelId, sessionKey: bodySessionKey, agentMode: bodyAgentMode, + channelInvocationMode: bodyChannelInvocationMode, clientVisibility: bodyClientVisibility, notifyOnCompletion: bodyNotifyOnCompletion, threadContext: bodyThreadContext, @@ -927,6 +961,9 @@ export async function POST(request: NextRequest) { request.cookies.get("active_workspace")?.value || null; const channelId = typeof bodyChannelId === "string" && bodyChannelId.trim() ? bodyChannelId.trim() : null; + const channelInvocationMode = firstString(bodyChannelInvocationMode)?.toLowerCase() === "mention" + ? "mention" + : "active"; const persistenceScope: ChatPersistenceScope = { companyId, workspaceId, channelId }; if (companyId || workspaceId) { const accessibleWorkspace = await resolveAccessibleWorkspace({ @@ -950,6 +987,9 @@ export async function POST(request: NextRequest) { channelId, companyId: accessibleWorkspace.companyId, agentCallsign: targetAgentCallsign || agentId, + agentMode: bodyAgentMode === true, + invocationMode: channelInvocationMode, + messageText: asString(lastUserMessage.content) ?? "", }); if (violation) { return Response.json({ error: violation }, { status: 403 }); diff --git a/src/app/api/runtimes/[id]/talk/realtime/session/route.test.ts b/src/app/api/runtimes/[id]/talk/realtime/session/route.test.ts index f6f2e7c..a772ec2 100644 --- a/src/app/api/runtimes/[id]/talk/realtime/session/route.test.ts +++ b/src/app/api/runtimes/[id]/talk/realtime/session/route.test.ts @@ -19,13 +19,19 @@ type ChannelMemberRow = { agentParticipationMode: string | null; }; -type DbRow = RuntimeRow | AgentRow | ChannelMemberRow; +type ChannelRow = { + id: string; + type: "channel" | "dm"; +}; + +type DbRow = RuntimeRow | AgentRow | ChannelMemberRow | ChannelRow; type Field = { table: string; key: string }; type Predicate = (row: DbRow) => boolean; -const { mockRuntimeRows, mockAgentRows, mockChannelMemberRows, mockGetGatewayClientForRuntime } = vi.hoisted(() => ({ +const { mockRuntimeRows, mockAgentRows, mockChannelRows, mockChannelMemberRows, mockGetGatewayClientForRuntime } = vi.hoisted(() => ({ mockRuntimeRows: [] as RuntimeRow[], mockAgentRows: [] as AgentRow[], + mockChannelRows: [] as ChannelRow[], mockChannelMemberRows: [] as ChannelMemberRow[], mockGetGatewayClientForRuntime: vi.fn(), })); @@ -41,6 +47,11 @@ vi.mock("@/db/schema", () => ({ id: { table: "agents", key: "id" }, callsign: { table: "agents", key: "callsign" }, }, + channels: { + __table: "channels", + id: { table: "channels", key: "id" }, + type: { table: "channels", key: "type" }, + }, channelMembers: { __table: "channelMembers", id: { table: "channelMembers", key: "id" }, @@ -61,6 +72,7 @@ vi.mock("drizzle-orm", () => ({ function rowsForTable(table: { __table: string }) { if (table.__table === "companyRuntimes") return mockRuntimeRows; if (table.__table === "agents") return mockAgentRows; + if (table.__table === "channels") return mockChannelRows; if (table.__table === "channelMembers") return mockChannelMemberRows; return []; } @@ -106,6 +118,7 @@ describe("POST /api/runtimes/[id]/talk/realtime/session", () => { vi.clearAllMocks(); mockRuntimeRows.length = 0; mockAgentRows.length = 0; + mockChannelRows.length = 0; mockChannelMemberRows.length = 0; }); @@ -210,6 +223,42 @@ describe("POST /api/runtimes/[id]/talk/realtime/session", () => { }); mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" }); mockAgentRows.push({ id: "agent_1", callsign: "neo" }); + mockChannelRows.push({ id: "channel_crew", type: "channel" }); + mockChannelMemberRows.push({ + id: "member_1", + channelId: "channel_crew", + memberType: "agent", + agentId: "agent_1", + role: "member", + agentParticipationMode: "on_call", + }); + mockGetGatewayClientForRuntime.mockResolvedValue({ realtimeTalkSession }); + + const response = await POST( + new Request("http://localhost/api/runtimes/rt_1/talk/realtime/session", { + method: "POST", + body: JSON.stringify({ + sessionKey: "main", + agentId: "main", + channelAgentId: "neo", + channelId: "channel_crew", + }), + }), + { params: Promise.resolve({ id: "rt_1" }) }, + ); + + expect(response.status).toBe(200); + expect(realtimeTalkSession).toHaveBeenCalledWith(expect.objectContaining({ + agentId: "main", + sessionKey: "main", + })); + }); + + it("rejects mention-only agents for shared-channel realtime sessions", async () => { + const realtimeTalkSession = vi.fn(); + mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" }); + mockAgentRows.push({ id: "agent_1", callsign: "neo" }); + mockChannelRows.push({ id: "channel_crew", type: "channel" }); mockChannelMemberRows.push({ id: "member_1", channelId: "channel_crew", @@ -233,6 +282,45 @@ describe("POST /api/runtimes/[id]/talk/realtime/session", () => { { params: Promise.resolve({ id: "rt_1" }) }, ); + expect(response.status).toBe(403); + await expect(response.json()).resolves.toEqual({ + error: "Agent is not an active participant in this channel.", + }); + expect(mockGetGatewayClientForRuntime).not.toHaveBeenCalled(); + expect(realtimeTalkSession).not.toHaveBeenCalled(); + }); + + it("allows direct agent DM realtime sessions even when the stored mode is mention-only", async () => { + const realtimeTalkSession = vi.fn().mockResolvedValue({ + transport: "gateway-relay", + relaySessionId: "relay_1", + }); + mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" }); + mockAgentRows.push({ id: "agent_1", callsign: "neo" }); + mockChannelRows.push({ id: "dm_neo", type: "dm" }); + mockChannelMemberRows.push({ + id: "member_1", + channelId: "dm_neo", + memberType: "agent", + agentId: "agent_1", + role: "member", + agentParticipationMode: "mention_only", + }); + mockGetGatewayClientForRuntime.mockResolvedValue({ realtimeTalkSession }); + + const response = await POST( + new Request("http://localhost/api/runtimes/rt_1/talk/realtime/session", { + method: "POST", + body: JSON.stringify({ + sessionKey: "main", + agentId: "main", + channelAgentId: "neo", + channelId: "dm_neo", + }), + }), + { params: Promise.resolve({ id: "rt_1" }) }, + ); + expect(response.status).toBe(200); expect(realtimeTalkSession).toHaveBeenCalledWith(expect.objectContaining({ agentId: "main", diff --git a/src/app/api/runtimes/[id]/talk/realtime/session/route.ts b/src/app/api/runtimes/[id]/talk/realtime/session/route.ts index d565b8b..d2817f4 100644 --- a/src/app/api/runtimes/[id]/talk/realtime/session/route.ts +++ b/src/app/api/runtimes/[id]/talk/realtime/session/route.ts @@ -1,7 +1,7 @@ import { NextResponse } from "next/server"; import { and, eq } from "drizzle-orm"; import { db, withRetry } from "@/db"; -import { agents, channelMembers, companyRuntimes } from "@/db/schema"; +import { agents, channelMembers, channels, companyRuntimes } from "@/db/schema"; import { buildRuntimeReadWhere, getAgentAccessContext } from "@/lib/agent-access"; import { getGatewayClientForRuntime } from "@/lib/gateway-chat-pool"; @@ -10,7 +10,7 @@ export const dynamic = "force-dynamic"; const REALTIME_SLOW_SPEECH_SILENCE_MS = 2000; const REALTIME_SLOW_SPEECH_PREFIX_PADDING_MS = 500; const CHANNEL_AGENT_SPEAKING_ROLES = new Set(["owner", "admin", "member", "contributor"]); -const CHANNEL_AGENT_SPEAKING_MODES = new Set(["mention_only", "proactive", "on_call"]); +const CHANNEL_AGENT_ACTIVE_MODES = new Set(["proactive", "on_call"]); export async function POST( request: Request, @@ -70,6 +70,15 @@ async function resolveRealtimeChannelAgentViolation(params: { const callsign = params.agentCallsign?.trim(); if (!callsign) return "Channel agent mention is required."; + const [channel] = await withRetry(() => + db! + .select({ type: channels.type }) + .from(channels) + .where(eq(channels.id, params.channelId)) + .limit(1) + ); + if (!channel) return "Agent is not a member of this channel."; + const [agent] = await withRetry(() => db! .select({ id: agents.id }) @@ -98,8 +107,8 @@ async function resolveRealtimeChannelAgentViolation(params: { if (!CHANNEL_AGENT_SPEAKING_ROLES.has(member.role)) { return "Agent cannot post in this channel."; } - if (!CHANNEL_AGENT_SPEAKING_MODES.has(member.agentParticipationMode ?? "mention_only")) { - return "Agent is not configured to respond in this channel."; + if (channel.type !== "dm" && !CHANNEL_AGENT_ACTIVE_MODES.has(member.agentParticipationMode ?? "mention_only")) { + return "Agent is not an active participant in this channel."; } return null; } diff --git a/src/app/chat/page.tsx b/src/app/chat/page.tsx index 35af0ba..44c55b0 100644 --- a/src/app/chat/page.tsx +++ b/src/app/chat/page.tsx @@ -2577,30 +2577,55 @@ export default function ChatPage() { () => new Map(agents.map((agent) => [agent.id, agent])), [agents] ); - const eligibleChannelAgents = useMemo(() => { + const mentionableChannelAgents = useMemo(() => { if (!activeChannel) return []; const speakingRoles = new Set(["owner", "admin", "member", "contributor"]); - const speakingModes = new Set(["mention_only", "proactive", "on_call"]); + const mentionModes = new Set(["mention_only", "proactive", "on_call"]); return agents.filter((agent) => { const membership = channelAgentMemberById.get(agent.id); if (!membership) return false; if (!speakingRoles.has(membership.role)) return false; - if (!speakingModes.has(membership.agentParticipationMode ?? "mention_only")) return false; + if (!mentionModes.has(membership.agentParticipationMode ?? "mention_only")) return false; return true; }); }, [activeChannel, agents, channelAgentMemberById]); - const selectedAgentCanSpeakInActiveChannel = useMemo(() => { + const activeChannelAgents = useMemo(() => { + if (!activeChannel) return []; + const speakingRoles = new Set(["owner", "admin", "member", "contributor"]); + const activeModes = new Set(["proactive", "on_call"]); + return agents.filter((agent) => { + const membership = channelAgentMemberById.get(agent.id); + if (!membership) return false; + if (!speakingRoles.has(membership.role)) return false; + if (activeChannel.type === "dm") return true; + if (!activeModes.has(membership.agentParticipationMode ?? "mention_only")) return false; + return true; + }); + }, [activeChannel, agents, channelAgentMemberById]); + const selectedAgentIsActiveInChannel = useMemo(() => { if (!activeChannelId) return true; if (!selectedAgent) return false; - return eligibleChannelAgents.some((agent) => sameAgent(agent, selectedAgent)); - }, [activeChannelId, eligibleChannelAgents, selectedAgent]); + return activeChannelAgents.some((agent) => sameAgent(agent, selectedAgent)); + }, [activeChannelAgents, activeChannelId, selectedAgent]); const channelAgentModeBlockReason = useMemo(() => { - if (!activeChannelId || selectedAgentCanSpeakInActiveChannel) return null; + if (!activeChannelId || selectedAgentIsActiveInChannel) return null; const channelName = activeChannel?.name ? `#${activeChannel.name}` : "this channel"; + if (activeChannelAgents.length === 0) { + return activeChannel?.type === "dm" + ? `Start a direct agent DM before using agent mode.` + : `Invite an agent as On call or Proactive before using agent mode in ${channelName}.`; + } return selectedAgent - ? `Invite @${selectedAgent.callsign} to ${channelName} before using agent mode.` - : `Select an agent member of ${channelName} before using agent mode.`; - }, [activeChannel?.name, activeChannelId, selectedAgent, selectedAgentCanSpeakInActiveChannel]); + ? `Set @${selectedAgent.callsign} to On call or Proactive before using agent mode in ${channelName}.` + : `Select an active agent participant in ${channelName} before using agent mode.`; + }, [activeChannel?.name, activeChannel?.type, activeChannelAgents.length, activeChannelId, selectedAgent, selectedAgentIsActiveInChannel]); + + useEffect(() => { + if (!activeChannelId || activeChannel?.type === "dm") return; + if (activeChannelAgents.length !== 1) return; + if (selectedAgent && sameAgent(selectedAgent, activeChannelAgents[0])) return; + setSelectedAgent(activeChannelAgents[0]); + }, [activeChannel?.type, activeChannelAgents, activeChannelId, selectedAgent]); const visibleMessages = useMemo( () => uniqueMessagesById( messages.filter(isVisibleChatMessage) @@ -3650,7 +3675,7 @@ export default function ChatPage() { // --- Wake word detection: check if user is addressing a specific agent --- const lowerTrimmed = trimmed.toLowerCase(); let wakeAgent: Agent | null = null; - const mentionableAgents = activeChannelId ? eligibleChannelAgents : agents; + const mentionableAgents = activeChannelId ? mentionableChannelAgents : agents; for (const agent of mentionableAgents) { const callsign = agent.callsign.toLowerCase(); const name = agent.name.toLowerCase(); @@ -3671,8 +3696,8 @@ export default function ChatPage() { } } const addressedAgent = wakeAgent ?? ( - activeChannel?.type === "dm" && eligibleChannelAgents.length === 1 - ? eligibleChannelAgents[0] + activeChannel?.type === "dm" && activeChannelAgents.length === 1 + ? activeChannelAgents[0] : null ); @@ -3866,6 +3891,9 @@ export default function ChatPage() { setChannelNotice(channelAgentModeBlockReason); return; } + const shouldSendToActiveChannelAgent = activeChannelId + ? Boolean(addressedAgent || options.forceVoiceResponse || selectedAgentIsActiveInChannel) + : true; const respondingDelegatedViaAgent = addressedAgent && defaultAgent && !sameAgent(addressedAgent, defaultAgent) ? defaultAgent : delegatedViaAgent; @@ -3912,7 +3940,7 @@ export default function ChatPage() { }); // User message persisted server-side in /api/chat route setInput(""); - if (activeChannelId && !addressedAgent && !options.forceVoiceResponse) { + if (activeChannelId && !shouldSendToActiveChannelAgent) { if (chatCompanyId || chatWorkspaceId) { try { const res = await fetch("/api/chat/messages", { @@ -4018,9 +4046,12 @@ export default function ChatPage() { companyId: chatCompanyId, workspaceId: chatWorkspaceId, channelId: activeChannelId, - metadata, - pageContext, - sessionKey: requestSessionKey, + channelInvocationMode: activeChannelId + ? addressedAgent ? "mention" : "active" + : undefined, + metadata, + pageContext, + sessionKey: requestSessionKey, agentMode: voiceMode === "agent", clientVisibility: typeof document !== "undefined" && document.hidden ? "hidden" : "visible", notifyOnCompletion: true, @@ -4328,7 +4359,7 @@ export default function ChatPage() { }, 0); } }, - [visibleMessages, queueSentenceForTTS, selectedAgent, speakResponses, agentAudioMuted, pendingFiles, agents, activeChannel?.type, eligibleChannelAgents, isPaused, stopWords, activeChannelId, chatCompanyId, chatWorkspaceId, selectedSessionKey, defaultAgent, delegatedViaAgent, persistExecutionSnapshot, refreshSessionPreview, enqueueMainMessage, setMainLoading, agentCallsign, voiceMode, pageContext, channelAgentModeBlockReason] + [visibleMessages, queueSentenceForTTS, selectedAgent, speakResponses, agentAudioMuted, pendingFiles, agents, activeChannel?.type, activeChannelAgents, mentionableChannelAgents, isPaused, stopWords, activeChannelId, chatCompanyId, chatWorkspaceId, selectedSessionKey, defaultAgent, delegatedViaAgent, persistExecutionSnapshot, refreshSessionPreview, enqueueMainMessage, setMainLoading, agentCallsign, voiceMode, pageContext, channelAgentModeBlockReason, selectedAgentIsActiveInChannel] ); const sendThreadMessage = useCallback(async ( @@ -4448,6 +4479,7 @@ export default function ChatPage() { companyId: chatCompanyId, workspaceId: chatWorkspaceId, channelId: activeChannelId, + channelInvocationMode: activeChannelId ? "active" : undefined, metadata, pageContext, sessionKey: thread.sessionKey, @@ -4925,9 +4957,9 @@ export default function ChatPage() { const activeConversationLabel = activeChannel ? `${activeChannel.type === "dm" ? "" : "# "}${activeChannel.name ?? "untitled"}` : agentCallsign; - const eligibleChannelAgentCallsigns = eligibleChannelAgents.map((agent) => `@${agent.callsign}`); - const channelAgentHint = eligibleChannelAgentCallsigns.length > 0 - ? `mention ${eligibleChannelAgentCallsigns.slice(0, 2).join(" or ")} to invite an agent` + const mentionableChannelAgentCallsigns = mentionableChannelAgents.map((agent) => `@${agent.callsign}`); + const channelAgentHint = mentionableChannelAgentCallsigns.length > 0 + ? `mention ${mentionableChannelAgentCallsigns.slice(0, 2).join(" or ")} to invite an agent` : "no shared agents in this channel"; const composerPlaceholder = isPaused ? `Say "${agentCallsign}" or @${agentCallsign} to resume...`