Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions packages/engine.io/lib/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,18 @@ export class Server extends EventEmitter<
const responseHeaders = new Headers();
if (this.opts.cors) {
addCorsHeaders(responseHeaders, this.opts.cors, req);

if (req.method === "OPTIONS") {
return new Response(null, { status: 204, headers: responseHeaders });
}
}

if (this.opts.editResponseHeaders) {
await this.opts.editResponseHeaders(responseHeaders, req, connInfo);
}

if (this.opts.cors) {
if (req.method === "OPTIONS") {
return new Response(null, { status: 204, headers: responseHeaders });
}
}

try {
await this.verify(req, url);
} catch (err) {
Expand Down Expand Up @@ -278,20 +280,26 @@ export class Server extends EventEmitter<
});
}
const previousTransport = client.transport.name;
if (previousTransport === "websocket") {
const isUpgradeRequest = req.headers.has("upgrade");
const isValidUpgrade = previousTransport === "polling" &&
transport === "websocket" &&
isUpgradeRequest;

if (
previousTransport === "websocket" ||
(!isValidUpgrade && transport !== previousTransport)
) {
getLogger("engine.io").debug(
"[server] unexpected transport without upgrade",
);
return Promise.reject(
{
code: ERROR_CODES.BAD_REQUEST,
context: {
name: "TRANSPORT_MISMATCH",
transport,
previousTransport,
},
return Promise.reject({
code: ERROR_CODES.BAD_REQUEST,
context: {
name: "TRANSPORT_MISMATCH",
transport,
previousTransport,
},
);
});
}
} else {
// handshake is GET only
Expand Down
30 changes: 30 additions & 0 deletions packages/engine.io/test/response_headers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,34 @@ describe("response headers", () => {
socket.onopen = done;
});
});

it("should send custom response headers for preflight requests", () => {
const engine = new Server({
cors: {
origin: ["https://example.com"],
},
editResponseHeaders: (responseHeaders) => {
responseHeaders.set("x-test", "123");
},
});

return setup(engine, 1, async (port, done) => {
const response = await fetch(
`http://localhost:${port}/engine.io/?EIO=4&transport=polling`,
{
method: "OPTIONS",
headers: {
origin: "https://example.com",
},
},
);

assertEquals(response.status, 204);
assertEquals(response.headers.get("x-test"), "123");

await response.body?.cancel();

done();
});
});
});
146 changes: 97 additions & 49 deletions packages/engine.io/test/verification.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,69 +177,61 @@ describe("verification", () => {
it("should disallow invalid handshake method", () => {
const engine = new Server();

return setup(
engine,
2,
async (port, partialDone) => {
engine.on("connection_error", (err) => {
assertExists(err.req);
assertEquals(err.code, 2);
assertEquals(err.message, "Bad handshake method");
assertEquals(err.context.method, "PUT");

partialDone();
});
return setup(engine, 2, async (port, partialDone) => {
engine.on("connection_error", (err) => {
assertExists(err.req);
assertEquals(err.code, 2);
assertEquals(err.message, "Bad handshake method");
assertEquals(err.context.method, "PUT");

const response = await fetch(
`http://localhost:${port}/engine.io/?transport=polling`,
{
method: "put",
},
);
partialDone();
});

assertEquals(response.status, 400);
const response = await fetch(
`http://localhost:${port}/engine.io/?transport=polling`,
{
method: "put",
},
);

const body = await response.json();
assertEquals(body.code, 2);
assertEquals(body.message, "Bad handshake method");
assertEquals(response.status, 400);

partialDone();
},
);
const body = await response.json();
assertEquals(body.code, 2);
assertEquals(body.message, "Bad handshake method");

partialDone();
});
});

it("should disallow unsupported protocol versions", () => {
const engine = new Server();

return setup(
engine,
2,
async (port, partialDone) => {
engine.on("connection_error", (err) => {
assertExists(err.req);
assertEquals(err.code, 5);
assertEquals(err.message, "Unsupported protocol version");
assertEquals(err.context.protocol, 3);
return setup(engine, 2, async (port, partialDone) => {
engine.on("connection_error", (err) => {
assertExists(err.req);
assertEquals(err.code, 5);
assertEquals(err.message, "Unsupported protocol version");
assertEquals(err.context.protocol, 3);

partialDone();
});
partialDone();
});

const response = await fetch(
`http://localhost:${port}/engine.io/?EIO=3&transport=polling`,
{
method: "get",
},
);
const response = await fetch(
`http://localhost:${port}/engine.io/?EIO=3&transport=polling`,
{
method: "get",
},
);

assertEquals(response.status, 400);
assertEquals(response.status, 400);

const body = await response.json();
assertEquals(body.code, 5);
assertEquals(body.message, "Unsupported protocol version");
const body = await response.json();
assertEquals(body.code, 5);
assertEquals(body.message, "Unsupported protocol version");

partialDone();
},
);
partialDone();
});
});

it("should disallow invalid transport", () => {
Expand Down Expand Up @@ -301,4 +293,60 @@ describe("verification", () => {
};
});
});

it("should disallow transport mismatch for an existing polling session", () => {
const engine = new Server();

return setup(engine, 2, async (port, partialDone) => {
engine.on("connection_error", (err) => {
assertExists(err.req);
assertEquals(err.code, 3);
assertEquals(err.message, "Bad request");
assertEquals(err.context.name, "TRANSPORT_MISMATCH");
assertEquals(err.context.transport, "websocket");
assertEquals(err.context.previousTransport, "polling");

partialDone();
});

const response = await fetch(
`http://localhost:${port}/engine.io/?EIO=4&transport=polling`,
{
method: "get",
},
);

const body = await response.text();
const sid = JSON.parse(body.substring(1)).sid;

let timerId: number | undefined;

const mismatchResponse = await Promise.race([
fetch(
`http://localhost:${port}/engine.io/?EIO=4&transport=websocket&sid=${sid}`,
{
method: "get",
},
),
new Promise<Response>((_, reject) => {
timerId = setTimeout(
() => reject(new Error("request timed out")),
200,
);
}),
]);

if (timerId !== undefined) {
clearTimeout(timerId);
}

assertEquals(mismatchResponse.status, 400);

const mismatchBody = await mismatchResponse.json();
assertEquals(mismatchBody.code, 3);
assertEquals(mismatchBody.message, "Bad request");

partialDone();
});
});
});
29 changes: 19 additions & 10 deletions packages/socket.io/lib/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,7 @@ export class Client<
this.decoder = decoder;
this.conn = conn;

const url = new URL(req.url);
this.handshake = {
url: url.pathname,
headers: req.headers,
query: url.searchParams,
address: (connInfo.remoteAddr as Deno.NetAddr).hostname,
secure: false,
xdomain: req.headers.has("origin"),
};
this.handshake = createHandshakeBase(req, connInfo);

conn.on("message", (data) => this.decoder.add(data));
conn.on("close", (reason) => this.onclose(reason));
Expand Down Expand Up @@ -171,7 +163,7 @@ export class Client<
_remove(
socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>,
): void {
this.sockets.delete(socket.id);
this.sockets.delete(socket.nsp.name);
}

private close() {
Expand Down Expand Up @@ -235,3 +227,20 @@ export class Client<
}
}
}

export function createHandshakeBase(
req: Request,
connInfo: Deno.ServeHandlerInfo,
): Omit<Handshake, "issued" | "time" | "auth"> {
const url = new URL(req.url);
const origin = req.headers.get("origin");

return {
url: url.pathname,
headers: req.headers,
query: url.searchParams,
address: (connInfo.remoteAddr as Deno.NetAddr).hostname,
secure: url.protocol === "https:",
xdomain: origin !== null && origin !== url.origin,
};
}
38 changes: 38 additions & 0 deletions packages/socket.io/test/client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { assertEquals, describe, it } from "../../../test_deps.ts";
import { createHandshakeBase } from "../lib/client.ts";

describe("client handshake metadata", () => {
it("should derive secure and cross-domain flags from the request URL and origin", () => {
const connInfo = {
remoteAddr: {
transport: "tcp",
hostname: "127.0.0.1",
port: 1234,
},
} as Deno.ServeHandlerInfo;

const sameOriginHandshake = createHandshakeBase(
new Request("https://example.com/socket.io/?EIO=4&transport=polling", {
headers: {
origin: "https://example.com",
},
}),
connInfo,
);

assertEquals(sameOriginHandshake.secure, true);
assertEquals(sameOriginHandshake.xdomain, false);

const crossOriginHandshake = createHandshakeBase(
new Request("https://example.com/socket.io/?EIO=4&transport=polling", {
headers: {
origin: "https://other.example.com",
},
}),
connInfo,
);

assertEquals(crossOriginHandshake.secure, true);
assertEquals(crossOriginHandshake.xdomain, true);
});
});
Loading
Loading