diff --git a/packages/engine.io/lib/server.ts b/packages/engine.io/lib/server.ts index 208b2f8..57bf6bb 100644 --- a/packages/engine.io/lib/server.ts +++ b/packages/engine.io/lib/server.ts @@ -162,16 +162,18 @@ export class Server extends EventEmitter< const responseHeaders = new Headers(); if (this.opts.cors) { addCorsHeaders(responseHeaders, this.opts.cors, req); - - if (req.method === "OPTIONS") { - return new Response(null, { status: 204, headers: responseHeaders }); - } } if (this.opts.editResponseHeaders) { await this.opts.editResponseHeaders(responseHeaders, req, connInfo); } + if (this.opts.cors) { + if (req.method === "OPTIONS") { + return new Response(null, { status: 204, headers: responseHeaders }); + } + } + try { await this.verify(req, url); } catch (err) { @@ -278,20 +280,26 @@ export class Server extends EventEmitter< }); } const previousTransport = client.transport.name; - if (previousTransport === "websocket") { + const isUpgradeRequest = req.headers.has("upgrade"); + const isValidUpgrade = previousTransport === "polling" && + transport === "websocket" && + isUpgradeRequest; + + if ( + previousTransport === "websocket" || + (!isValidUpgrade && transport !== previousTransport) + ) { getLogger("engine.io").debug( "[server] unexpected transport without upgrade", ); - return Promise.reject( - { - code: ERROR_CODES.BAD_REQUEST, - context: { - name: "TRANSPORT_MISMATCH", - transport, - previousTransport, - }, + return Promise.reject({ + code: ERROR_CODES.BAD_REQUEST, + context: { + name: "TRANSPORT_MISMATCH", + transport, + previousTransport, }, - ); + }); } } else { // handshake is GET only diff --git a/packages/engine.io/test/response_headers.test.ts b/packages/engine.io/test/response_headers.test.ts index 5818e61..19ba419 100644 --- a/packages/engine.io/test/response_headers.test.ts +++ b/packages/engine.io/test/response_headers.test.ts @@ -65,4 +65,34 @@ describe("response headers", () => { socket.onopen = done; }); }); + + it("should send custom response headers for preflight requests", () => { + const engine = new Server({ + cors: { + origin: ["https://example.com"], + }, + editResponseHeaders: (responseHeaders) => { + responseHeaders.set("x-test", "123"); + }, + }); + + return setup(engine, 1, async (port, done) => { + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "OPTIONS", + headers: { + origin: "https://example.com", + }, + }, + ); + + assertEquals(response.status, 204); + assertEquals(response.headers.get("x-test"), "123"); + + await response.body?.cancel(); + + done(); + }); + }); }); diff --git a/packages/engine.io/test/verification.test.ts b/packages/engine.io/test/verification.test.ts index d756eaf..711c459 100644 --- a/packages/engine.io/test/verification.test.ts +++ b/packages/engine.io/test/verification.test.ts @@ -177,69 +177,61 @@ describe("verification", () => { it("should disallow invalid handshake method", () => { const engine = new Server(); - return setup( - engine, - 2, - async (port, partialDone) => { - engine.on("connection_error", (err) => { - assertExists(err.req); - assertEquals(err.code, 2); - assertEquals(err.message, "Bad handshake method"); - assertEquals(err.context.method, "PUT"); - - partialDone(); - }); + return setup(engine, 2, async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertExists(err.req); + assertEquals(err.code, 2); + assertEquals(err.message, "Bad handshake method"); + assertEquals(err.context.method, "PUT"); - const response = await fetch( - `http://localhost:${port}/engine.io/?transport=polling`, - { - method: "put", - }, - ); + partialDone(); + }); - assertEquals(response.status, 400); + const response = await fetch( + `http://localhost:${port}/engine.io/?transport=polling`, + { + method: "put", + }, + ); - const body = await response.json(); - assertEquals(body.code, 2); - assertEquals(body.message, "Bad handshake method"); + assertEquals(response.status, 400); - partialDone(); - }, - ); + const body = await response.json(); + assertEquals(body.code, 2); + assertEquals(body.message, "Bad handshake method"); + + partialDone(); + }); }); it("should disallow unsupported protocol versions", () => { const engine = new Server(); - return setup( - engine, - 2, - async (port, partialDone) => { - engine.on("connection_error", (err) => { - assertExists(err.req); - assertEquals(err.code, 5); - assertEquals(err.message, "Unsupported protocol version"); - assertEquals(err.context.protocol, 3); + return setup(engine, 2, async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertExists(err.req); + assertEquals(err.code, 5); + assertEquals(err.message, "Unsupported protocol version"); + assertEquals(err.context.protocol, 3); - partialDone(); - }); + partialDone(); + }); - const response = await fetch( - `http://localhost:${port}/engine.io/?EIO=3&transport=polling`, - { - method: "get", - }, - ); + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=3&transport=polling`, + { + method: "get", + }, + ); - assertEquals(response.status, 400); + assertEquals(response.status, 400); - const body = await response.json(); - assertEquals(body.code, 5); - assertEquals(body.message, "Unsupported protocol version"); + const body = await response.json(); + assertEquals(body.code, 5); + assertEquals(body.message, "Unsupported protocol version"); - partialDone(); - }, - ); + partialDone(); + }); }); it("should disallow invalid transport", () => { @@ -301,4 +293,60 @@ describe("verification", () => { }; }); }); + + it("should disallow transport mismatch for an existing polling session", () => { + const engine = new Server(); + + return setup(engine, 2, async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertExists(err.req); + assertEquals(err.code, 3); + assertEquals(err.message, "Bad request"); + assertEquals(err.context.name, "TRANSPORT_MISMATCH"); + assertEquals(err.context.transport, "websocket"); + assertEquals(err.context.previousTransport, "polling"); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const body = await response.text(); + const sid = JSON.parse(body.substring(1)).sid; + + let timerId: number | undefined; + + const mismatchResponse = await Promise.race([ + fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=websocket&sid=${sid}`, + { + method: "get", + }, + ), + new Promise((_, reject) => { + timerId = setTimeout( + () => reject(new Error("request timed out")), + 200, + ); + }), + ]); + + if (timerId !== undefined) { + clearTimeout(timerId); + } + + assertEquals(mismatchResponse.status, 400); + + const mismatchBody = await mismatchResponse.json(); + assertEquals(mismatchBody.code, 3); + assertEquals(mismatchBody.message, "Bad request"); + + partialDone(); + }); + }); }); diff --git a/packages/socket.io/lib/client.ts b/packages/socket.io/lib/client.ts index b4c549d..9dc9e0e 100644 --- a/packages/socket.io/lib/client.ts +++ b/packages/socket.io/lib/client.ts @@ -49,15 +49,7 @@ export class Client< this.decoder = decoder; this.conn = conn; - const url = new URL(req.url); - this.handshake = { - url: url.pathname, - headers: req.headers, - query: url.searchParams, - address: (connInfo.remoteAddr as Deno.NetAddr).hostname, - secure: false, - xdomain: req.headers.has("origin"), - }; + this.handshake = createHandshakeBase(req, connInfo); conn.on("message", (data) => this.decoder.add(data)); conn.on("close", (reason) => this.onclose(reason)); @@ -171,7 +163,7 @@ export class Client< _remove( socket: Socket, ): void { - this.sockets.delete(socket.id); + this.sockets.delete(socket.nsp.name); } private close() { @@ -235,3 +227,20 @@ export class Client< } } } + +export function createHandshakeBase( + req: Request, + connInfo: Deno.ServeHandlerInfo, +): Omit { + const url = new URL(req.url); + const origin = req.headers.get("origin"); + + return { + url: url.pathname, + headers: req.headers, + query: url.searchParams, + address: (connInfo.remoteAddr as Deno.NetAddr).hostname, + secure: url.protocol === "https:", + xdomain: origin !== null && origin !== url.origin, + }; +} diff --git a/packages/socket.io/test/client.test.ts b/packages/socket.io/test/client.test.ts new file mode 100644 index 0000000..142c513 --- /dev/null +++ b/packages/socket.io/test/client.test.ts @@ -0,0 +1,38 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { createHandshakeBase } from "../lib/client.ts"; + +describe("client handshake metadata", () => { + it("should derive secure and cross-domain flags from the request URL and origin", () => { + const connInfo = { + remoteAddr: { + transport: "tcp", + hostname: "127.0.0.1", + port: 1234, + }, + } as Deno.ServeHandlerInfo; + + const sameOriginHandshake = createHandshakeBase( + new Request("https://example.com/socket.io/?EIO=4&transport=polling", { + headers: { + origin: "https://example.com", + }, + }), + connInfo, + ); + + assertEquals(sameOriginHandshake.secure, true); + assertEquals(sameOriginHandshake.xdomain, false); + + const crossOriginHandshake = createHandshakeBase( + new Request("https://example.com/socket.io/?EIO=4&transport=polling", { + headers: { + origin: "https://other.example.com", + }, + }), + connInfo, + ); + + assertEquals(crossOriginHandshake.secure, true); + assertEquals(crossOriginHandshake.xdomain, true); + }); +}); diff --git a/packages/socket.io/test/handshake.test.ts b/packages/socket.io/test/handshake.test.ts index 61fee3f..88e27c4 100644 --- a/packages/socket.io/test/handshake.test.ts +++ b/packages/socket.io/test/handshake.test.ts @@ -145,37 +145,63 @@ describe("handshake", () => { ); }); - it("should trigger a connection event (custom namespace)", () => { + it("should reconnect to a namespace after a client-side namespace disconnect", () => { const io = new Server(); + io.of("/custom"); - return setup( - io, - 2, - async (port, partialDone) => { - io.of("/custom").on("connection", (socket) => { - assertExists(socket.id); - partialDone(); - }); + return setup(io, 1, async (port, done) => { + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); - const response = await fetch( - `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, - { - method: "get", - }, - ); + assertEquals(response.status, 200); - assertEquals(response.status, 200); + const sid = await parseSessionID(response); - const sid = await parseSessionID(response); + await eioPush(port, sid, "40/custom,"); + const firstConnectBody = await eioPoll(port, sid); + assertEquals(firstConnectBody.startsWith("40/custom,{"), true); - await eioPush(port, sid, "40/custom,"); + await eioPush(port, sid, "41/custom,"); - const body = await eioPoll(port, sid); - assertEquals(body.startsWith("40/custom,{"), true); + await eioPush(port, sid, "40/custom,"); + const secondConnectBody = await eioPoll(port, sid); + assertEquals(secondConnectBody.startsWith("40/custom,{"), true); + + done(); + }); + }); + it("should trigger a connection event (custom namespace)", () => { + const io = new Server(); + + return setup(io, 2, async (port, partialDone) => { + io.of("/custom").on("connection", (socket) => { + assertExists(socket.id); partialDone(); - }, - ); + }); + + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 200); + + const sid = await parseSessionID(response); + + await eioPush(port, sid, "40/custom,"); + + const body = await eioPoll(port, sid); + assertEquals(body.startsWith("40/custom,{"), true); + + partialDone(); + }); }); it("should trigger a connection event (dynamic namespace)", () => {