Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 126 additions & 10 deletions src/app/api/runtimes/[id]/talk/realtime/session/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,84 @@ type RuntimeRow = {
ownerUserId: string | null;
};

type Field = { key: keyof RuntimeRow };
type Predicate = (row: RuntimeRow) => boolean;
type AgentRow = {
id: string;
callsign: string;
};

const { mockRuntimeRows, mockGetGatewayClientForRuntime } = vi.hoisted(() => ({
type ChannelMemberRow = {
id: string;
channelId: string;
memberType: "user" | "agent";
agentId: string | null;
role: string;
agentParticipationMode: string | null;
};

type DbRow = RuntimeRow | AgentRow | ChannelMemberRow;
type Field = { table: string; key: string };
type Predicate = (row: DbRow) => boolean;

const { mockRuntimeRows, mockAgentRows, mockChannelMemberRows, mockGetGatewayClientForRuntime } = vi.hoisted(() => ({
mockRuntimeRows: [] as RuntimeRow[],
mockAgentRows: [] as AgentRow[],
mockChannelMemberRows: [] as ChannelMemberRow[],
mockGetGatewayClientForRuntime: vi.fn(),
}));

vi.mock("@/db/schema", () => ({
companyRuntimes: {
id: { key: "id" },
ownerUserId: { key: "ownerUserId" },
__table: "companyRuntimes",
id: { table: "companyRuntimes", key: "id" },
ownerUserId: { table: "companyRuntimes", key: "ownerUserId" },
},
agents: {
__table: "agents",
id: { table: "agents", key: "id" },
callsign: { table: "agents", key: "callsign" },
},
channelMembers: {
__table: "channelMembers",
id: { table: "channelMembers", key: "id" },
channelId: { table: "channelMembers", key: "channelId" },
memberType: { table: "channelMembers", key: "memberType" },
agentId: { table: "channelMembers", key: "agentId" },
role: { table: "channelMembers", key: "role" },
agentParticipationMode: { table: "channelMembers", key: "agentParticipationMode" },
},
}));

vi.mock("drizzle-orm", () => ({
eq: (field: Field, value: unknown): Predicate => (row) => row[field.key] === value,
eq: (field: Field, value: unknown): Predicate => (row) => (row as Record<string, unknown>)[field.key] === value,
and: (...predicates: Array<Predicate | undefined>): Predicate => (row) =>
predicates.every((predicate) => predicate?.(row) ?? true),
}));

function rowsForTable(table: { __table: string }) {
if (table.__table === "companyRuntimes") return mockRuntimeRows;
if (table.__table === "agents") return mockAgentRows;
if (table.__table === "channelMembers") return mockChannelMemberRows;
return [];
}

function projectRows(rows: DbRow[], selection: Record<string, Field>) {
return rows.map((row) =>
Object.fromEntries(
Object.entries(selection).map(([key, field]) => [
key,
(row as Record<string, unknown>)[field.key],
]),
),
);
}

vi.mock("@/db", () => ({
db: {
select: () => ({
from: () => ({
select: (selection: Record<string, Field>) => ({
from: (table: { __table: string }) => ({
where: (predicate: Predicate) => ({
limit: (count: number) => Promise.resolve(mockRuntimeRows.filter(predicate).slice(0, count)),
limit: (count: number) =>
Promise.resolve(projectRows(rowsForTable(table).filter(predicate).slice(0, count), selection)),
}),
}),
}),
Expand All @@ -41,7 +92,7 @@ vi.mock("@/db", () => ({

vi.mock("@/lib/agent-access", () => ({
getAgentAccessContext: () => ({ userId: "user_1", activeCompanyId: null, memberships: [] }),
buildRuntimeReadWhere: () => (row: RuntimeRow) => row.ownerUserId === "user_1",
buildRuntimeReadWhere: () => (row: DbRow) => (row as RuntimeRow).ownerUserId === "user_1",
}));

vi.mock("@/lib/gateway-chat-pool", () => ({
Expand All @@ -54,6 +105,8 @@ describe("POST /api/runtimes/[id]/talk/realtime/session", () => {
beforeEach(() => {
vi.clearAllMocks();
mockRuntimeRows.length = 0;
mockAgentRows.length = 0;
mockChannelMemberRows.length = 0;
});

it("proxies realtime talk session requests through an accessible runtime", async () => {
Expand Down Expand Up @@ -124,6 +177,69 @@ describe("POST /api/runtimes/[id]/talk/realtime/session", () => {
}));
});

it("rejects channel-scoped realtime sessions for agents outside the channel", async () => {
const realtimeTalkSession = vi.fn();
mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" });
mockAgentRows.push({ id: "agent_1", callsign: "neo" });
mockGetGatewayClientForRuntime.mockResolvedValue({ realtimeTalkSession });

const response = await POST(
new Request("http://localhost/api/runtimes/rt_1/talk/realtime/session", {
method: "POST",
body: JSON.stringify({
agentId: "main",
channelAgentId: "neo",
channelId: "channel_crew",
}),
}),
{ params: Promise.resolve({ id: "rt_1" }) },
);

expect(response.status).toBe(403);
await expect(response.json()).resolves.toEqual({
error: "Agent is not a member of this channel.",
});
expect(mockGetGatewayClientForRuntime).not.toHaveBeenCalled();
expect(realtimeTalkSession).not.toHaveBeenCalled();
});

it("allows channel-scoped realtime sessions for eligible channel agents", async () => {
const realtimeTalkSession = vi.fn().mockResolvedValue({
transport: "gateway-relay",
relaySessionId: "relay_1",
});
mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" });
mockAgentRows.push({ id: "agent_1", callsign: "neo" });
mockChannelMemberRows.push({
id: "member_1",
channelId: "channel_crew",
memberType: "agent",
agentId: "agent_1",
role: "member",
agentParticipationMode: "mention_only",
});
mockGetGatewayClientForRuntime.mockResolvedValue({ realtimeTalkSession });

const response = await POST(
new Request("http://localhost/api/runtimes/rt_1/talk/realtime/session", {
method: "POST",
body: JSON.stringify({
sessionKey: "main",
agentId: "main",
channelAgentId: "neo",
channelId: "channel_crew",
}),
}),
{ params: Promise.resolve({ id: "rt_1" }) },
);

expect(response.status).toBe(200);
expect(realtimeTalkSession).toHaveBeenCalledWith(expect.objectContaining({
agentId: "main",
sessionKey: "main",
}));
});

it("does not call the gateway for unreadable runtimes", async () => {
mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_2" });

Expand Down
55 changes: 54 additions & 1 deletion src/app/api/runtimes/[id]/talk/realtime/session/route.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import { NextResponse } from "next/server";
import { and, eq } from "drizzle-orm";
import { db, withRetry } from "@/db";
import { companyRuntimes } from "@/db/schema";
import { agents, channelMembers, companyRuntimes } from "@/db/schema";
import { buildRuntimeReadWhere, getAgentAccessContext } from "@/lib/agent-access";
import { getGatewayClientForRuntime } from "@/lib/gateway-chat-pool";

export const dynamic = "force-dynamic";

const REALTIME_SLOW_SPEECH_SILENCE_MS = 2000;
const REALTIME_SLOW_SPEECH_PREFIX_PADDING_MS = 500;
const CHANNEL_AGENT_SPEAKING_ROLES = new Set(["owner", "admin", "member", "contributor"]);
const CHANNEL_AGENT_SPEAKING_MODES = new Set(["mention_only", "proactive", "on_call"]);

export async function POST(
request: Request,
Expand All @@ -32,6 +34,16 @@ export async function POST(
if (!runtime) return NextResponse.json({ error: "Runtime not found" }, { status: 404 });

const body = await request.json().catch(() => ({}));
const channelId = readOptionalString(body.channelId);
const channelAgentId = readOptionalString(body.channelAgentId) ?? readOptionalString(body.agentId);
if (channelId) {
const violation = await resolveRealtimeChannelAgentViolation({
channelId,
agentCallsign: channelAgentId,
});
if (violation) return NextResponse.json({ error: violation }, { status: 403 });
}

const client = await getGatewayClientForRuntime(runtime.id);
const session = await client.realtimeTalkSession({
sessionKey: readOptionalString(body.sessionKey),
Expand All @@ -51,6 +63,47 @@ export async function POST(
}
}

async function resolveRealtimeChannelAgentViolation(params: {
channelId: string;
agentCallsign?: string;
}) {
const callsign = params.agentCallsign?.trim();
if (!callsign) return "Channel agent mention is required.";

const [agent] = await withRetry(() =>
db!
.select({ id: agents.id })
.from(agents)
.where(eq(agents.callsign, callsign))
.limit(1)
);
if (!agent) return "Agent is not a member of this channel.";

const [member] = await withRetry(() =>
db!
.select({
role: channelMembers.role,
agentParticipationMode: channelMembers.agentParticipationMode,
})
.from(channelMembers)
.where(and(
eq(channelMembers.channelId, params.channelId),
eq(channelMembers.memberType, "agent"),
eq(channelMembers.agentId, agent.id),
))
.limit(1)
);

if (!member) return "Agent is not a member of this channel.";
if (!CHANNEL_AGENT_SPEAKING_ROLES.has(member.role)) {
return "Agent cannot post in this channel.";
}
if (!CHANNEL_AGENT_SPEAKING_MODES.has(member.agentParticipationMode ?? "mention_only")) {
return "Agent is not configured to respond in this channel.";
}
return null;
}

function readOptionalString(value: unknown) {
return typeof value === "string" && value.trim().length > 0 ? value.trim() : undefined;
}
Expand Down
Loading
Loading