Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions src/app/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { requireAuth } from "@/lib/require-auth";
import { isValidVoiceUploadToken } from "@/lib/voice-upload-tokens";
import { getGatewayClient, holdClient, releaseClient } from "@/lib/gateway-chat-pool";
import { db, withRetry } from "@/db";
import { agents, channelMembers, chatMessages, chatRuns, chatSessions, chatThreads } from "@/db/schema";
import { agents, channelMembers, channels, chatMessages, chatRuns, chatSessions, chatThreads } from "@/db/schema";
import { eq, desc, and, isNull, sql } from "drizzle-orm";
import { publishChatEvent, publishChatProgressEvent } from "@/lib/chat-pubsub";
import { isAssistantDeliveryPlaceholder, selectRecoveredAssistantText } from "@/lib/chat-recovery";
Expand All @@ -27,7 +27,8 @@ const ACTIVE_HISTORY_POLL_INTERVAL_MS = 1_500;
const ACTIVE_HISTORY_POLL_MAX_ATTEMPTS = 80;
const AGENT_MODE_THINKING_LEVEL = "low";
const CHANNEL_AGENT_SPEAKING_ROLES = new Set(["owner", "admin", "member", "contributor"]);
const CHANNEL_AGENT_SPEAKING_MODES = new Set(["mention_only", "proactive", "on_call"]);
const CHANNEL_AGENT_MENTION_MODES = new Set(["mention_only", "proactive", "on_call"]);
const CHANNEL_AGENT_ACTIVE_MODES = new Set(["proactive", "on_call"]);

type ChatProgressEventName =
| "run_started"
Expand Down Expand Up @@ -812,6 +813,9 @@ async function resolveChannelAgentInvocationViolation(params: {
channelId: string;
companyId: string | null;
agentCallsign: string;
agentMode: boolean;
invocationMode: "active" | "mention";
messageText: string;
}) {
if (!db) return null;
const callsign = params.agentCallsign.trim().toLowerCase();
Expand All @@ -821,6 +825,7 @@ async function resolveChannelAgentInvocationViolation(params: {
db!.select({
agent: agents,
member: channelMembers,
channel: channels,
})
.from(agents)
.innerJoin(
Expand All @@ -830,6 +835,7 @@ async function resolveChannelAgentInvocationViolation(params: {
eq(channelMembers.channelId, params.channelId),
),
)
.innerJoin(channels, eq(channels.id, channelMembers.channelId))
.where(sql`lower(${agents.callsign}) = ${callsign}`)
.limit(1)
);
Expand All @@ -838,9 +844,19 @@ async function resolveChannelAgentInvocationViolation(params: {
if (!CHANNEL_AGENT_SPEAKING_ROLES.has(row.member.role)) {
return "Agent cannot post in this channel.";
}
if (!CHANNEL_AGENT_SPEAKING_MODES.has(row.member.agentParticipationMode ?? "mention_only")) {
const participationMode = row.member.agentParticipationMode ?? "mention_only";
const isDirectAgentDm = row.channel.type === "dm";
const isActiveParticipant = isDirectAgentDm || CHANNEL_AGENT_ACTIVE_MODES.has(participationMode);
const isExplicitMention = !params.agentMode &&
params.invocationMode === "mention" &&
messageMentionsAgent(params.messageText, callsign, row.agent.name);

if (isExplicitMention && !CHANNEL_AGENT_MENTION_MODES.has(participationMode) && !isDirectAgentDm) {
return "Agent is not configured to respond in this channel.";
}
if (!isExplicitMention && !isActiveParticipant) {
return "Agent is not an active participant in this channel.";
}
if (params.companyId) {
const ownerCompanyId = row.agent.ownerCompanyId ?? row.agent.companyId ?? null;
if (row.agent.ownerType !== "company" || ownerCompanyId !== params.companyId) {
Expand All @@ -850,6 +866,23 @@ async function resolveChannelAgentInvocationViolation(params: {
return null;
}

/**
 * Escapes every character that is special inside a regular expression so that
 * `value` can be embedded in a `RegExp` source string and matched literally.
 *
 * @param value - Raw text to embed in a pattern.
 * @returns `value` with each regex metacharacter prefixed by a backslash.
 */
function escapeRegExp(value: string) {
  return value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

/**
 * Reports whether `message` explicitly addresses an agent by its callsign or
 * display name. Matching is case-insensitive and recognises three forms:
 *
 * - `@alias` anywhere in the message, as long as the `@` is not glued to a
 *   preceding word character (so `bob@neo.com` is not a mention of `neo`);
 * - the message starting with `alias` followed by `,`, `:` or whitespace;
 * - the message starting with `hey alias`.
 *
 * @param message - Text of the user message to inspect.
 * @param callsign - Agent callsign; surrounding whitespace is ignored.
 * @param name - Optional agent display name, treated as a second alias.
 *   `null`, `undefined`, and blank names are skipped.
 * @returns `true` when at least one alias is mentioned in one of the forms above.
 */
function messageMentionsAgent(message: string, callsign: string, name?: string | null) {
  // Blank aliases must be dropped: an empty alias would turn the patterns
  // below into ones that match almost any message.
  const aliases = [callsign, name].filter((value): value is string => Boolean(value?.trim()));
  // `(?!\w)` rather than `\b`: a trailing `\b` can never match after an alias
  // that ends in a non-word character (e.g. "C++"), whereas the lookahead only
  // demands that the alias is not immediately continued by a word character.
  const aliasEnd = "(?!\\w)";
  return aliases.some((alias) => {
    const escaped = escapeRegExp(alias.trim().toLowerCase());
    return [
      // `@` is itself a non-word character, so a leading `\b@` would require a
      // word character *before* the `@` — the opposite of a mention. Accept the
      // start of the message or any non-word character instead.
      new RegExp(`(?:^|\\W)@${escaped}${aliasEnd}`, "i"),
      new RegExp(`^${escaped}[,:\\s]`, "i"),
      new RegExp(`^hey\\s+${escaped}${aliasEnd}`, "i"),
    ].some((pattern) => pattern.test(message));
  });
}

export async function POST(request: NextRequest) {
const bearerToken = request.headers.get("authorization")?.replace(/^Bearer\s+/i, "");
const authError = isValidVoiceUploadToken(bearerToken) ? null : await requireAuth(request);
Expand All @@ -867,6 +900,7 @@ export async function POST(request: NextRequest) {
channelId: bodyChannelId,
sessionKey: bodySessionKey,
agentMode: bodyAgentMode,
channelInvocationMode: bodyChannelInvocationMode,
clientVisibility: bodyClientVisibility,
notifyOnCompletion: bodyNotifyOnCompletion,
threadContext: bodyThreadContext,
Expand Down Expand Up @@ -927,6 +961,9 @@ export async function POST(request: NextRequest) {
request.cookies.get("active_workspace")?.value ||
null;
const channelId = typeof bodyChannelId === "string" && bodyChannelId.trim() ? bodyChannelId.trim() : null;
const channelInvocationMode = firstString(bodyChannelInvocationMode)?.toLowerCase() === "mention"
? "mention"
: "active";
const persistenceScope: ChatPersistenceScope = { companyId, workspaceId, channelId };
if (companyId || workspaceId) {
const accessibleWorkspace = await resolveAccessibleWorkspace({
Expand All @@ -950,6 +987,9 @@ export async function POST(request: NextRequest) {
channelId,
companyId: accessibleWorkspace.companyId,
agentCallsign: targetAgentCallsign || agentId,
agentMode: bodyAgentMode === true,
invocationMode: channelInvocationMode,
messageText: asString(lastUserMessage.content) ?? "",
});
if (violation) {
return Response.json({ error: violation }, { status: 403 });
Expand Down
92 changes: 90 additions & 2 deletions src/app/api/runtimes/[id]/talk/realtime/session/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@ type ChannelMemberRow = {
agentParticipationMode: string | null;
};

type DbRow = RuntimeRow | AgentRow | ChannelMemberRow;
/**
 * Row shape stored in the mocked `channels` table.
 * `type` separates shared channels ("channel") from direct-message channels ("dm").
 */
type ChannelRow = {
id: string;
type: "channel" | "dm";
};

type DbRow = RuntimeRow | AgentRow | ChannelMemberRow | ChannelRow;
type Field = { table: string; key: string };
type Predicate = (row: DbRow) => boolean;

const { mockRuntimeRows, mockAgentRows, mockChannelMemberRows, mockGetGatewayClientForRuntime } = vi.hoisted(() => ({
const { mockRuntimeRows, mockAgentRows, mockChannelRows, mockChannelMemberRows, mockGetGatewayClientForRuntime } = vi.hoisted(() => ({
mockRuntimeRows: [] as RuntimeRow[],
mockAgentRows: [] as AgentRow[],
mockChannelRows: [] as ChannelRow[],
mockChannelMemberRows: [] as ChannelMemberRow[],
mockGetGatewayClientForRuntime: vi.fn(),
}));
Expand All @@ -41,6 +47,11 @@ vi.mock("@/db/schema", () => ({
id: { table: "agents", key: "id" },
callsign: { table: "agents", key: "callsign" },
},
channels: {
__table: "channels",
id: { table: "channels", key: "id" },
type: { table: "channels", key: "type" },
},
channelMembers: {
__table: "channelMembers",
id: { table: "channelMembers", key: "id" },
Expand All @@ -61,6 +72,7 @@ vi.mock("drizzle-orm", () => ({
/**
 * Looks up the in-memory fixture array that backs a mocked schema table.
 *
 * @param table - Mocked table object; only its `__table` tag is inspected.
 * @returns The hoisted mock row array for that table, or an empty array when
 *   the tag is not one of the known tables.
 */
function rowsForTable(table: { __table: string }) {
  switch (table.__table) {
    case "companyRuntimes":
      return mockRuntimeRows;
    case "agents":
      return mockAgentRows;
    case "channels":
      return mockChannelRows;
    case "channelMembers":
      return mockChannelMemberRows;
    default:
      return [];
  }
}
Expand Down Expand Up @@ -106,6 +118,7 @@ describe("POST /api/runtimes/[id]/talk/realtime/session", () => {
vi.clearAllMocks();
mockRuntimeRows.length = 0;
mockAgentRows.length = 0;
mockChannelRows.length = 0;
mockChannelMemberRows.length = 0;
});

Expand Down Expand Up @@ -210,6 +223,42 @@ describe("POST /api/runtimes/[id]/talk/realtime/session", () => {
});
mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" });
mockAgentRows.push({ id: "agent_1", callsign: "neo" });
mockChannelRows.push({ id: "channel_crew", type: "channel" });
mockChannelMemberRows.push({
id: "member_1",
channelId: "channel_crew",
memberType: "agent",
agentId: "agent_1",
role: "member",
agentParticipationMode: "on_call",
});
mockGetGatewayClientForRuntime.mockResolvedValue({ realtimeTalkSession });

const response = await POST(
new Request("http://localhost/api/runtimes/rt_1/talk/realtime/session", {
method: "POST",
body: JSON.stringify({
sessionKey: "main",
agentId: "main",
channelAgentId: "neo",
channelId: "channel_crew",
}),
}),
{ params: Promise.resolve({ id: "rt_1" }) },
);

expect(response.status).toBe(200);
expect(realtimeTalkSession).toHaveBeenCalledWith(expect.objectContaining({
agentId: "main",
sessionKey: "main",
}));
});

it("rejects mention-only agents for shared-channel realtime sessions", async () => {
const realtimeTalkSession = vi.fn();
mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" });
mockAgentRows.push({ id: "agent_1", callsign: "neo" });
mockChannelRows.push({ id: "channel_crew", type: "channel" });
mockChannelMemberRows.push({
id: "member_1",
channelId: "channel_crew",
Expand All @@ -233,6 +282,45 @@ describe("POST /api/runtimes/[id]/talk/realtime/session", () => {
{ params: Promise.resolve({ id: "rt_1" }) },
);

expect(response.status).toBe(403);
await expect(response.json()).resolves.toEqual({
error: "Agent is not an active participant in this channel.",
});
expect(mockGetGatewayClientForRuntime).not.toHaveBeenCalled();
expect(realtimeTalkSession).not.toHaveBeenCalled();
});

it("allows direct agent DM realtime sessions even when the stored mode is mention-only", async () => {
const realtimeTalkSession = vi.fn().mockResolvedValue({
transport: "gateway-relay",
relaySessionId: "relay_1",
});
mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" });
mockAgentRows.push({ id: "agent_1", callsign: "neo" });
mockChannelRows.push({ id: "dm_neo", type: "dm" });
mockChannelMemberRows.push({
id: "member_1",
channelId: "dm_neo",
memberType: "agent",
agentId: "agent_1",
role: "member",
agentParticipationMode: "mention_only",
});
mockGetGatewayClientForRuntime.mockResolvedValue({ realtimeTalkSession });

const response = await POST(
new Request("http://localhost/api/runtimes/rt_1/talk/realtime/session", {
method: "POST",
body: JSON.stringify({
sessionKey: "main",
agentId: "main",
channelAgentId: "neo",
channelId: "dm_neo",
}),
}),
{ params: Promise.resolve({ id: "rt_1" }) },
);

expect(response.status).toBe(200);
expect(realtimeTalkSession).toHaveBeenCalledWith(expect.objectContaining({
agentId: "main",
Expand Down
17 changes: 13 additions & 4 deletions src/app/api/runtimes/[id]/talk/realtime/session/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { NextResponse } from "next/server";
import { and, eq } from "drizzle-orm";
import { db, withRetry } from "@/db";
import { agents, channelMembers, companyRuntimes } from "@/db/schema";
import { agents, channelMembers, channels, companyRuntimes } from "@/db/schema";
import { buildRuntimeReadWhere, getAgentAccessContext } from "@/lib/agent-access";
import { getGatewayClientForRuntime } from "@/lib/gateway-chat-pool";

Expand All @@ -10,7 +10,7 @@ export const dynamic = "force-dynamic";
const REALTIME_SLOW_SPEECH_SILENCE_MS = 2000;
const REALTIME_SLOW_SPEECH_PREFIX_PADDING_MS = 500;
const CHANNEL_AGENT_SPEAKING_ROLES = new Set(["owner", "admin", "member", "contributor"]);
const CHANNEL_AGENT_SPEAKING_MODES = new Set(["mention_only", "proactive", "on_call"]);
const CHANNEL_AGENT_ACTIVE_MODES = new Set(["proactive", "on_call"]);

export async function POST(
request: Request,
Expand Down Expand Up @@ -70,6 +70,15 @@ async function resolveRealtimeChannelAgentViolation(params: {
const callsign = params.agentCallsign?.trim();
if (!callsign) return "Channel agent mention is required.";

const [channel] = await withRetry(() =>
db!
.select({ type: channels.type })
.from(channels)
.where(eq(channels.id, params.channelId))
.limit(1)
);
if (!channel) return "Agent is not a member of this channel.";

const [agent] = await withRetry(() =>
db!
.select({ id: agents.id })
Expand Down Expand Up @@ -98,8 +107,8 @@ async function resolveRealtimeChannelAgentViolation(params: {
if (!CHANNEL_AGENT_SPEAKING_ROLES.has(member.role)) {
return "Agent cannot post in this channel.";
}
if (!CHANNEL_AGENT_SPEAKING_MODES.has(member.agentParticipationMode ?? "mention_only")) {
return "Agent is not configured to respond in this channel.";
if (channel.type !== "dm" && !CHANNEL_AGENT_ACTIVE_MODES.has(member.agentParticipationMode ?? "mention_only")) {
return "Agent is not an active participant in this channel.";
}
return null;
}
Expand Down
Loading
Loading