diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 40389295..f04b2e1b 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -15,6 +15,7 @@ "dompurify": "^3.3.1", "jose": "^6.1.3", "lucide-react": "^0.576.0", + "qrcode.react": "^4.2.0", "react": "^19.2.0", "react-dom": "^19.2.0", "react-router": "^7.13.1", @@ -31,9 +32,11 @@ "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-refresh": "^0.5.0", "globals": "^17.0.0", + "happy-dom": "^20.7.0", "typescript": "~5.9.3", "typescript-eslint": "^8.48.0", - "vite": "^7.3.1" + "vite": "^7.3.1", + "vitest": "^4.0.18" } }, "node_modules/@babel/code-frame": { @@ -1353,6 +1356,13 @@ "win32" ] }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "dev": true, + "license": "MIT" + }, "node_modules/@tailwindcss/node": { "version": "4.2.1", "resolved": "https://registry.npmjs.org/@tailwindcss/node/-/node-4.2.1.tgz", @@ -1713,6 +1723,24 @@ "@babel/types": "^7.28.2" } }, + "node_modules/@types/chai": { + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/@types/chai/-/chai-5.2.3.tgz", + "integrity": "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/deep-eql": "*", + "assertion-error": "^2.0.1" + } + }, + "node_modules/@types/deep-eql": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz", + "integrity": "sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/estree": { "version": "1.0.8", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", @@ -1763,6 +1791,23 @@ "license": "MIT", "optional": true }, + "node_modules/@types/whatwg-mimetype": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@types/whatwg-mimetype/-/whatwg-mimetype-3.0.2.tgz", + "integrity": "sha512-c2AKvDT8ToxLIOUlN51gTiHXflsfIFisS4pO7pDPoKouJCESkhZnEy623gwP9laCy5lnLDAw1vAzu2vM2YLOrA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/ws": { + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.56.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.56.1.tgz", @@ -2079,6 +2124,117 @@ "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" } }, + "node_modules/@vitest/expect": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-4.0.18.tgz", + "integrity": "sha512-8sCWUyckXXYvx4opfzVY03EOiYVxyNrHS5QxX3DAIi5dpJAAkyJezHCP77VMX4HKA2LDT/Jpfo8i2r5BE3GnQQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@standard-schema/spec": "^1.0.0", + "@types/chai": "^5.2.2", + "@vitest/spy": "4.0.18", + "@vitest/utils": "4.0.18", + "chai": "^6.2.1", + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/mocker": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-4.0.18.tgz", + "integrity": "sha512-HhVd0MDnzzsgevnOWCBj5Otnzobjy5wLBe4EdeeFGv8luMsGcYqDuFRMcttKWZA5vVO8RFjexVovXvAM4JoJDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/spy": "4.0.18", + "estree-walker": "^3.0.3", + "magic-string": "^0.30.21" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "msw": "^2.4.9", + "vite": "^6.0.0 || ^7.0.0-0" + }, + "peerDependenciesMeta": { + "msw": { + "optional": true + }, + "vite": { + "optional": true + } + } + }, + "node_modules/@vitest/pretty-format": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-4.0.18.tgz", + "integrity": "sha512-P24GK3GulZWC5tz87ux0m8OADrQIUVDPIjjj65vBXYG17ZeU3qD7r+MNZ1RNv4l8CGU2vtTRqixrOi9fYk/yKw==", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/runner": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-4.0.18.tgz", + "integrity": "sha512-rpk9y12PGa22Jg6g5M3UVVnTS7+zycIGk9ZNGN+m6tZHKQb7jrP7/77WfZy13Y/EUDd52NDsLRQhYKtv7XfPQw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/utils": "4.0.18", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/snapshot": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-4.0.18.tgz", + "integrity": "sha512-PCiV0rcl7jKQjbgYqjtakly6T1uwv/5BQ9SwBLekVg/EaYeQFPiXcgrC2Y7vDMA8dM1SUEAEV82kgSQIlXNMvA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.0.18", + "magic-string": "^0.30.21", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/spy": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-4.0.18.tgz", + "integrity": "sha512-cbQt3PTSD7P2OARdVW3qWER5EGq7PHlvE+QfzSC0lbwO+xnt7+XH06ZzFjFRgzUX//JmpxrCu92VdwvEPlWSNw==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/utils": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-4.0.18.tgz", + "integrity": "sha512-msMRKLMVLWygpK3u2Hybgi4MNjcYJvwTb0Ru09+fOyCXIgT5raYP041DRRdiJiI3k/2U6SEbAETB3YtBrUkCFA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.0.18", + "tinyrainbow": "^3.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, "node_modules/acorn": { "version": "8.16.0", "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.16.0.tgz", @@ -2142,6 +2298,16 @@ "dev": true, "license": "Python-2.0" }, + "node_modules/assertion-error": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz", + "integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + } + }, "node_modules/balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", @@ -2238,6 +2404,16 @@ ], "license": "CC-BY-4.0" }, + "node_modules/chai": { + "version": "6.2.2", + "resolved": "https://registry.npmjs.org/chai/-/chai-6.2.2.tgz", + "integrity": "sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -2408,6 +2584,26 @@ "node": ">=10.13.0" } }, + "node_modules/entities": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/entities/-/entities-7.0.1.tgz", + "integrity": "sha512-TWrgLOFUQTH994YUyl1yT4uyavY5nNB5muff+RtWaqNVCAK408b5ZnnbNAUEWLTCpum9w6arT70i1XdQ4UeOPA==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.12" + }, + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, + "node_modules/es-module-lexer": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.7.0.tgz", + "integrity": "sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==", + "dev": true, + "license": "MIT" + }, "node_modules/esbuild": { "version": "0.27.3", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.3.tgz", @@ -2646,6 +2842,16 @@ "node": ">=4.0" } }, + "node_modules/estree-walker": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", + "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -2656,6 +2862,16 @@ "node": ">=0.10.0" } }, + "node_modules/expect-type": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.3.0.tgz", + "integrity": "sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -2783,9 +2999,9 @@ } }, "node_modules/globals": { - "version": "17.4.0", - "resolved": "https://registry.npmjs.org/globals/-/globals-17.4.0.tgz", - "integrity": "sha512-hjrNztw/VajQwOLsMNT1cbJiH2muO3OROCHnbehc8eY5JyD2gqz4AcMHPqgaOR59DjgUjYAYLeH699g/eWi2jw==", + "version": "17.3.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-17.3.0.tgz", + "integrity": "sha512-yMqGUQVVCkD4tqjOJf3TnrvaaHDMYp4VlUSObbkIiuCPe/ofdMBFIAcBbCSRFWOnos6qRiTVStDwqPLUclaxIw==", "dev": true, "license": "MIT", "engines": { @@ -2801,6 +3017,24 @@ "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", "license": "ISC" }, + "node_modules/happy-dom": { + "version": "20.7.0", + "resolved": "https://registry.npmjs.org/happy-dom/-/happy-dom-20.7.0.tgz", + "integrity": "sha512-hR/uLYQdngTyEfxnOoa+e6KTcfBFyc1hgFj/Cc144A5JJUuHFYqIEBDcD4FeGqUeKLRZqJ9eN9u7/GDjYEgS1g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": ">=20.0.0", + "@types/whatwg-mimetype": "^3.0.2", + "@types/ws": "^8.18.1", + "entities": "^7.0.1", + "whatwg-mimetype": "^3.0.0", + "ws": "^8.18.3" + }, + "engines": { + "node": ">=20.0.0" + } + }, "node_modules/has-flag": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", @@ -3356,6 +3590,17 @@ "dev": true, "license": "MIT" }, + "node_modules/obug": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/obug/-/obug-2.1.1.tgz", + "integrity": "sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/sxzz", + "https://opencollective.com/debug" + ], + "license": "MIT" + }, "node_modules/optionator": { "version": "0.9.4", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", @@ -3439,6 +3684,13 @@ "node": ">=8" } }, + "node_modules/pathe": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", + "integrity": "sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==", + "dev": true, + "license": "MIT" + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", @@ -3505,6 +3757,15 @@ "node": ">=6" } }, + "node_modules/qrcode.react": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/qrcode.react/-/qrcode.react-4.2.0.tgz", + "integrity": "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA==", + "license": "ISC", + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, "node_modules/react": { "version": "19.2.4", "resolved": "https://registry.npmjs.org/react/-/react-19.2.4.tgz", @@ -3657,6 +3918,13 @@ "node": ">=8" } }, + "node_modules/siginfo": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz", + "integrity": "sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==", + "dev": true, + "license": "ISC" + }, "node_modules/source-map-js": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", @@ -3666,6 +3934,20 @@ "node": ">=0.10.0" } }, + "node_modules/stackback": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz", + "integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==", + "dev": true, + "license": "MIT" + }, + "node_modules/std-env": { + "version": "3.10.0", + "resolved": "https://registry.npmjs.org/std-env/-/std-env-3.10.0.tgz", + "integrity": "sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==", + "dev": true, + "license": "MIT" + }, "node_modules/strip-json-comments": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", @@ -3721,6 +4003,23 @@ "url": "https://opencollective.com/webpack" } }, + "node_modules/tinybench": { + "version": "2.9.0", + "resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.9.0.tgz", + "integrity": "sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==", + "dev": true, + "license": "MIT" + }, + "node_modules/tinyexec": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/tinyexec/-/tinyexec-1.0.2.tgz", + "integrity": "sha512-W/KYk+NFhkmsYpuHq5JykngiOCnxeVL8v8dFnqxSD8qEEdRfXk1SDM6JzNqcERbcGYj9tMrDQBYV9cjgnunFIg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -3737,6 +4036,16 @@ "url": "https://github.com/sponsors/SuperchupuDev" } }, + "node_modules/tinyrainbow": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-3.0.3.tgz", + "integrity": "sha512-PSkbLUoxOFRzJYjjxHJt9xro7D+iilgMX/C9lawzVuYiIdcihh9DXmVibBe8lmcFrRi/VzlPjBxbN7rH24q8/Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/ts-api-utils": { "version": "2.4.0", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.4.0.tgz", @@ -3923,6 +4232,94 @@ } } }, + "node_modules/vitest": { + "version": "4.0.18", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-4.0.18.tgz", + "integrity": "sha512-hOQuK7h0FGKgBAas7v0mSAsnvrIgAvWmRFjmzpJ7SwFHH3g1k2u37JtYwOwmEKhK6ZO3v9ggDBBm0La1LCK4uQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/expect": "4.0.18", + "@vitest/mocker": "4.0.18", + "@vitest/pretty-format": "4.0.18", + "@vitest/runner": "4.0.18", + "@vitest/snapshot": "4.0.18", + "@vitest/spy": "4.0.18", + "@vitest/utils": "4.0.18", + "es-module-lexer": "^1.7.0", + "expect-type": "^1.2.2", + "magic-string": "^0.30.21", + "obug": "^2.1.1", + "pathe": "^2.0.3", + "picomatch": "^4.0.3", + "std-env": "^3.10.0", + "tinybench": "^2.9.0", + "tinyexec": "^1.0.2", + "tinyglobby": "^0.2.15", + "tinyrainbow": "^3.0.3", + "vite": "^6.0.0 || ^7.0.0", + "why-is-node-running": "^2.3.0" + }, + "bin": { + "vitest": "vitest.mjs" + }, + "engines": { + "node": "^20.0.0 || ^22.0.0 || >=24.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@edge-runtime/vm": "*", + "@opentelemetry/api": "^1.9.0", + "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", + "@vitest/browser-playwright": "4.0.18", + "@vitest/browser-preview": "4.0.18", + "@vitest/browser-webdriverio": "4.0.18", + "@vitest/ui": "4.0.18", + "happy-dom": "*", + "jsdom": "*" + }, + "peerDependenciesMeta": { + "@edge-runtime/vm": { + "optional": true + }, + "@opentelemetry/api": { + "optional": true + }, + "@types/node": { + "optional": true + }, + "@vitest/browser-playwright": { + "optional": true + }, + "@vitest/browser-preview": { + "optional": true + }, + "@vitest/browser-webdriverio": { + "optional": true + }, + "@vitest/ui": { + "optional": true + }, + "happy-dom": { + "optional": true + }, + "jsdom": { + "optional": true + } + } + }, + "node_modules/whatwg-mimetype": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-3.0.0.tgz", + "integrity": "sha512-nt+N2dzIutVRxARx1nghPKGv1xHikU7HKdfafKkLNLindmPU/ch3U31NOCGGA/dmPcmb1VlofO0vnKAcsm0o/Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + } + }, "node_modules/which": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -3939,6 +4336,23 @@ "node": ">= 8" } }, + "node_modules/why-is-node-running": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz", + "integrity": "sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==", + "dev": true, + "license": "MIT", + "dependencies": { + "siginfo": "^2.0.0", + "stackback": "0.0.2" + }, + "bin": { + "why-is-node-running": "cli.js" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/word-wrap": { "version": "1.2.5", "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", @@ -3949,6 +4363,28 @@ "node": ">=0.10.0" } }, + "node_modules/ws": { + "version": "8.19.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz", + "integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, "node_modules/yallist": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", diff --git a/frontend/package.json b/frontend/package.json index 9b46d6b2..37835771 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -17,6 +17,7 @@ "dompurify": "^3.3.1", "jose": "^6.1.3", "lucide-react": "^0.576.0", + "qrcode.react": "^4.2.0", "react": "^19.2.0", "react-dom": "^19.2.0", "react-router": "^7.13.1", @@ -33,8 +34,10 @@ "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-refresh": "^0.5.0", "globals": "^17.0.0", + "happy-dom": "^20.7.0", "typescript": "~5.9.3", "typescript-eslint": "^8.48.0", - "vite": "^7.3.1" + "vite": "^7.3.1", + "vitest": "^4.0.18" } } diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 3abb7921..27df9f26 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,5 +1,5 @@ import { useCallback, useEffect, useRef, useState } from "react" -import { useSearchParams } from "react-router" +import { Routes, Route, useSearchParams } from "react-router" import type { AppConfig } from "@/lib/types" import { checkBrowserCompatibility } from "@/lib/browser-check" import { loadConfig } from "@/lib/config" @@ -10,6 +10,8 @@ import { ErrorPage } from "@/components/ErrorPage" import { BrowserWarning } from "@/components/BrowserWarning" import { SignInPage } from "@/components/SignInPage" import { DashboardPage } from "@/components/DashboardPage" +import { PairAuthorityPage } from "@/components/PairAuthorityPage" +import { PairSupplicantPage } from "@/components/PairSupplicantPage" import { ThemeProvider } from "@/components/theme-provider" import { ThemeToggle } from "@/components/theme-toggle" @@ -140,11 +142,57 @@ function MainFlow() { ) } +function PairRoute({ + Component, +}: { + Component: React.ComponentType<{ config: AppConfig }> +}) { + const [config, setConfig] = useState(null) + const [error, setError] = useState(null) + const initialized = useRef(false) + + useEffect(() => { + if (initialized.current) return + initialized.current = true + loadConfig() + .then(setConfig) + .catch((err) => + setError(err instanceof Error ? err.message : String(err)) + ) + }, []) + + if (error) { + return ( + window.location.reload()} + /> + ) + } + + if (!config) { + return + } + + return +} + export default function App() { return ( - + + } + /> + } + /> + } /> + ) diff --git a/frontend/src/components/PairAuthorityPage.tsx b/frontend/src/components/PairAuthorityPage.tsx new file mode 100644 index 00000000..0ad47623 --- /dev/null +++ b/frontend/src/components/PairAuthorityPage.tsx @@ -0,0 +1,330 @@ +import { useCallback, useEffect, useRef, useState } from "react" +import { QRCodeSVG } from "qrcode.react" +import { Loader2, CheckCircle2, XCircle, Smartphone } from "lucide-react" +import { Button } from "@/components/ui/button" +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card" +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert" +import type { AppConfig } from "@/lib/types" +import { PairingChannel } from "@/lib/pairing-channel" +import { buildPairUrl } from "@/lib/pairing" +import { requestOAuthCode } from "@/lib/auth-client" +import { generatePKCE } from "@/lib/pkce" +import { + listenFromFirefox, + sendPairComplete, + sendPairDecline, +} from "@/lib/webchannel" +import * as session from "@/lib/session" + +type PairAuthorityState = + | { step: "creating-channel" } + | { step: "showing-qr"; pairUrl: string; channel: PairingChannel } + | { + step: "confirming" + channel: PairingChannel + suppRequest: SuppRequest + } + | { step: "authorizing" } + | { step: "complete" } + | { step: "error"; message: string } + +interface SuppRequest { + client_id: string + state: string + scope: string + code_challenge: string + code_challenge_method: string + keys_jwk?: string +} + +interface PairAuthorityPageProps { + config: AppConfig +} + +export function PairAuthorityPage({ config }: PairAuthorityPageProps) { + const [state, setState] = useState({ + step: "creating-channel", + }) + const initialized = useRef(false) + const channelRef = useRef(null) + + const handleError = useCallback((msg: string) => { + console.error(`[ffsync:pair] Authority error: ${msg}`) + setState({ step: "error", message: msg }) + }, []) + + useEffect(() => { + if (initialized.current) return + initialized.current = true + + const cleanup = listenFromFirefox((command, _data, messageId) => { + if (command === "fxaccounts:pair_decline") { + sendPairDecline(messageId) + if (channelRef.current && !channelRef.current.closed) { + channelRef.current.close().catch(() => {}) + } + setState({ step: "error", message: "Pairing was declined." }) + } else if (command === "fxaccounts:pair_complete") { + sendPairComplete(messageId) + } + }) + + initChannel() + + return () => { + cleanup() + if (channelRef.current && !channelRef.current.closed) { + channelRef.current.close().catch(() => {}) + } + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + async function initChannel() { + if (!config.pairingServerUrl) { + handleError( + "pairingServerUrl is not configured. Add it to config.json." + ) + return + } + + try { + const channel = await PairingChannel.create(config.pairingServerUrl) + channelRef.current = channel + + const contentUrl = config.redirectUri.replace(/\/+$/, "") + const pairUrl = buildPairUrl( + contentUrl, + channel.channelId, + channel.channelKey + ) + + setState({ step: "showing-qr", pairUrl, channel }) + + // Listen for supplicant request + channel.addEventListener("message", (event: Event) => { + const detail = (event as CustomEvent).detail + const data = detail.data + if (data.message === "pair:supp:request") { + setState({ + step: "confirming", + channel, + suppRequest: data.data as SuppRequest, + }) + // Send authority metadata back + const auth = session.getAuth() + channel.send({ + message: "pair:auth:metadata", + data: { + email: auth?.email ?? "unknown", + uid: auth?.uid ?? "unknown", + }, + }) + } + }) + + channel.addEventListener("error", () => { + handleError("Pairing channel connection error.") + }) + + channel.addEventListener("close", () => { + // Peer closed the channel + }) + } catch (err) { + handleError(err instanceof Error ? err.message : String(err)) + } + } + + async function handleApprove() { + if (state.step !== "confirming") return + const { channel, suppRequest } = state + + setState({ step: "authorizing" }) + + try { + const auth = session.getAuth() + if (!auth || !config.authServerUrl) { + handleError( + "You must be signed in to approve pairing. Please sign in first." + ) + return + } + + const redirectUri = + "urn:ietf:wg:oauth:2.0:oob:pair-auth-webchannel" + + // Generate our own PKCE if needed, or use the supplicant's + let codeChallenge = suppRequest.code_challenge + if (!codeChallenge) { + const pkce = await generatePKCE() + codeChallenge = pkce.codeChallenge + } + + const oauthResult = await requestOAuthCode( + config.authServerUrl, + auth.sessionToken, + suppRequest.client_id || config.clientId, + suppRequest.scope || + "https://identity.mozilla.com/apps/oldsync profile", + suppRequest.state || crypto.randomUUID(), + codeChallenge, + undefined, + redirectUri + ) + + // Send the OAuth code through the pairing channel + await channel.send({ + message: "pair:auth:authorize", + data: { + code: oauthResult.code, + state: oauthResult.state, + redirect: oauthResult.redirect, + }, + }) + + setState({ step: "complete" }) + } catch (err) { + handleError(err instanceof Error ? err.message : String(err)) + } + } + + function handleDecline() { + if (state.step !== "confirming") return + const { channel } = state + channel.send({ message: "pair:auth:decline", data: {} }).catch(() => {}) + channel.close().catch(() => {}) + setState({ step: "error", message: "Pairing declined." }) + } + + if (state.step === "creating-channel") { + return ( + + + +

+ Creating pairing channel... +

+
+
+ ) + } + + if (state.step === "showing-qr") { + return ( + + + Pair a Device + + Scan this QR code with the device you want to pair. + + + +
+ +
+

+ Waiting for the other device to connect... +

+
+
+ ) + } + + if (state.step === "confirming") { + return ( + + +
+ + Confirm Pairing +
+ + A device wants to connect to your account. + +
+ +
+

Requesting access to:

+

+ {state.suppRequest.scope || "Sync data"} +

+
+
+ + +
+
+
+ ) + } + + if (state.step === "authorizing") { + return ( + + + +

+ Authorizing device... +

+
+
+ ) + } + + if (state.step === "complete") { + return ( + + +
+ + Device Paired +
+ + The device has been successfully paired. You can close this page. + +
+
+ ) + } + + if (state.step === "error") { + return ( + + +
+ + Pairing Error +
+
+ + + Error + {state.message} + + + +
+ ) + } + + return null +} diff --git a/frontend/src/components/PairSupplicantPage.tsx b/frontend/src/components/PairSupplicantPage.tsx new file mode 100644 index 00000000..92436db7 --- /dev/null +++ b/frontend/src/components/PairSupplicantPage.tsx @@ -0,0 +1,238 @@ +import { useEffect, useRef, useState } from "react" +import { Loader2, CheckCircle2, XCircle } from "lucide-react" +import { Button } from "@/components/ui/button" +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card" +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert" +import type { AppConfig } from "@/lib/types" +import { PairingChannel } from "@/lib/pairing-channel" +import { parsePairFragment } from "@/lib/pairing" +import { generatePKCE } from "@/lib/pkce" + +type PairSupplicantState = + | { step: "connecting" } + | { step: "waiting" } + | { + step: "confirming" + email: string + uid: string + } + | { step: "complete"; code: string; state: string } + | { step: "error"; message: string } + +interface PairSupplicantPageProps { + config: AppConfig +} + +export function PairSupplicantPage({ config }: PairSupplicantPageProps) { + const [state, setState] = useState({ + step: "connecting", + }) + const initialized = useRef(false) + const channelRef = useRef(null) + + useEffect(() => { + if (initialized.current) return + initialized.current = true + + initConnection() + + return () => { + if (channelRef.current && !channelRef.current.closed) { + channelRef.current.close().catch(() => {}) + } + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + async function initConnection() { + if (!config.pairingServerUrl) { + setState({ + step: "error", + message: + "pairingServerUrl is not configured. Add it to config.json.", + }) + return + } + + const fragment = window.location.hash + const parsed = parsePairFragment(fragment) + if (!parsed) { + setState({ + step: "error", + message: + "Invalid pairing link. Missing channel_id or channel_key in URL fragment.", + }) + return + } + + const { channelId, channelKey } = parsed + + try { + const channel = await PairingChannel.connect( + config.pairingServerUrl, + channelId, + channelKey + ) + channelRef.current = channel + + // Generate PKCE for the OAuth flow + const pkce = await generatePKCE() + + // Send the supplicant request + await channel.send({ + message: "pair:supp:request", + data: { + client_id: config.clientId, + state: crypto.randomUUID(), + scope: + "https://identity.mozilla.com/apps/oldsync profile", + code_challenge: pkce.codeChallenge, + code_challenge_method: "S256", + }, + }) + + setState({ step: "waiting" }) + + // Listen for authority messages + channel.addEventListener("message", (event: Event) => { + const detail = (event as CustomEvent).detail + const data = detail.data + + if (data.message === "pair:auth:metadata") { + setState({ + step: "confirming", + email: data.data.email, + uid: data.data.uid, + }) + } else if (data.message === "pair:auth:authorize") { + setState({ + step: "complete", + code: data.data.code, + state: data.data.state, + }) + channel.close().catch(() => {}) + } else if (data.message === "pair:auth:decline") { + setState({ + step: "error", + message: "Pairing was declined by the other device.", + }) + channel.close().catch(() => {}) + } + }) + + channel.addEventListener("error", () => { + setState({ + step: "error", + message: "Pairing channel connection error.", + }) + }) + + channel.addEventListener("close", () => { + // Peer closed the channel + }) + } catch (err) { + setState({ + step: "error", + message: err instanceof Error ? err.message : String(err), + }) + } + } + + if (state.step === "connecting") { + return ( + + + +

+ Connecting to pairing channel... +

+
+
+ ) + } + + if (state.step === "waiting") { + return ( + + + +

+ Waiting for the other device to approve... +

+
+
+ ) + } + + if (state.step === "confirming") { + return ( + + + Pairing in Progress + + Connecting to the account below. Please confirm on the other + device. + + + +
+

Account

+

{state.email}

+
+

+ Waiting for approval on the other device... +

+
+
+ ) + } + + if (state.step === "complete") { + return ( + + +
+ + Pairing Complete +
+ + This device has been successfully paired. You can close this page. + +
+
+ ) + } + + if (state.step === "error") { + return ( + + +
+ + Pairing Error +
+
+ + + Error + {state.message} + + + +
+ ) + } + + return null +} diff --git a/frontend/src/lib/auth-client.ts b/frontend/src/lib/auth-client.ts index 32c8fbd3..ec26720c 100644 --- a/frontend/src/lib/auth-client.ts +++ b/frontend/src/lib/auth-client.ts @@ -122,7 +122,8 @@ export async function requestOAuthCode( scope: string, state: string, codeChallenge: string, - keysJwe?: string + keysJwe?: string, + redirectUri?: string ): Promise { const url = `${authServerUrl}/v1/oauth/authorization` const authorization = await buildHawkHeader(sessionToken, "POST", url) @@ -137,6 +138,9 @@ export async function requestOAuthCode( if (keysJwe) { bodyObj.keys_jwe = keysJwe } + if (redirectUri) { + bodyObj.redirect_uri = redirectUri + } return authFetch( url, { diff --git a/frontend/src/lib/config.ts b/frontend/src/lib/config.ts index 57187605..e58997d6 100644 --- a/frontend/src/lib/config.ts +++ b/frontend/src/lib/config.ts @@ -39,6 +39,9 @@ export async function loadConfig(): Promise { if (config.authServerUrl) { config.authServerUrl = config.authServerUrl.replace(/\/+$/, "") } + if (config.pairingServerUrl) { + config.pairingServerUrl = config.pairingServerUrl.replace(/\/+$/, "") + } return config } diff --git a/frontend/src/lib/pairing-channel/__tests__/helpers.ts b/frontend/src/lib/pairing-channel/__tests__/helpers.ts new file mode 100644 index 00000000..98fde34c --- /dev/null +++ b/frontend/src/lib/pairing-channel/__tests__/helpers.ts @@ -0,0 +1,434 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { expect } from "vitest" +import { + BufferWriter, + arrayToBytes, + bytesAreEqual, + zeros, +} from "../utils" +import { HASH_LENGTH } from "../crypto" +import { KeySchedule } from "../keyschedule" +import { EncryptionState, DecryptionState } from "../recordlayer" +import { TEST_VECTORS } from "./test-vectors" + +export const testHelpers = { + nextTick: function (): Promise { + return new Promise((res) => setTimeout(res, 1)) + }, + + tamper: function (bytes: Uint8Array, where = 0): Uint8Array { + const tampered = bytes.slice() + tampered[where] += 1 + expect(bytesAreEqual(bytes, tampered)).toBe(false) + return tampered + }, + + decryptInnerPlaintext: async function ( + cipherstate: DecryptionState, + bytes: Uint8Array + ): Promise { + return await cipherstate.decrypt(bytes.slice(5), bytes.slice(0, 5)) + }, + + makePlaintextRecord: function (opts: { + type?: number + version?: number + content?: Uint8Array + contentLength?: number + trailer?: Uint8Array + }): Uint8Array { + const buf = new BufferWriter() + const content = + typeof opts.content !== "undefined" + ? opts.content + : arrayToBytes([1, 2, 3, 4, 5]) + buf.writeUint8(typeof opts.type !== "undefined" ? opts.type : 22) + buf.writeUint16( + typeof opts.version !== "undefined" ? opts.version : 0x0303 + ) + buf.writeUint16( + typeof opts.contentLength !== "undefined" + ? opts.contentLength + : content.byteLength + ) + buf.writeBytes(content) + if (typeof opts.trailer !== "undefined") { + buf.writeBytes(opts.trailer) + } + return buf.flush() + }, + + makeEncryptedInnerPlaintext: async function ( + cipherstate: EncryptionState, + opts: { + content?: Uint8Array + innerPlaintext?: Uint8Array + type?: number + outerType?: number + outerVersion?: number + ciphertextLength?: number + padding?: number + } + ): Promise { + const adBuf = new BufferWriter() + const innerPlaintextBuf = new BufferWriter() + const plaintext = + typeof opts.content !== "undefined" + ? opts.content + : arrayToBytes([1, 2, 3, 4, 5]) + if (typeof opts.innerPlaintext !== "undefined") { + innerPlaintextBuf.writeBytes(opts.innerPlaintext) + } else { + innerPlaintextBuf.writeBytes(plaintext) + innerPlaintextBuf.writeUint8( + typeof opts.type !== "undefined" ? opts.type : 23 + ) + if (opts.padding) { + innerPlaintextBuf.writeBytes(zeros(opts.padding)) + } + } + const ciphertextLength = innerPlaintextBuf.tell() + 16 + adBuf.writeUint8( + typeof opts.outerType !== "undefined" ? opts.outerType : 23 + ) + adBuf.writeUint16( + typeof opts.outerVersion !== "undefined" ? opts.outerVersion : 0x0303 + ) + adBuf.writeUint16( + typeof opts.ciphertextLength !== "undefined" + ? opts.ciphertextLength + : ciphertextLength + ) + const ciphertext = await cipherstate.encrypt( + innerPlaintextBuf.flush(), + adBuf.flush() + ) + return ciphertext + }, + + makeEncryptedRecord: async function ( + cipherstate: EncryptionState, + opts: { + content?: Uint8Array + innerPlaintext?: Uint8Array + type?: number + outerType?: number + outerVersion?: number + outerContentLength?: number + outerTrailer?: Uint8Array + ciphertextLength?: number + ciphertext?: Uint8Array + padding?: number + } + ): Promise { + let ciphertext = opts.ciphertext + if (typeof ciphertext === "undefined") { + ciphertext = await testHelpers.makeEncryptedInnerPlaintext( + cipherstate, + opts + ) + } + return testHelpers.makePlaintextRecord({ + content: ciphertext, + contentLength: opts.outerContentLength, + trailer: opts.outerTrailer, + type: + typeof opts.outerType !== "undefined" ? opts.outerType : 23, + version: opts.outerVersion, + }) + }, + + makeEncryptionState: async function ( + key: Uint8Array, + seqnum = 0 + ): Promise { + const encryptor = await EncryptionState.create(key) + encryptor.seqnum = seqnum + return encryptor + }, + + makeDecryptionState: async function ( + key: Uint8Array, + seqnum = 0 + ): Promise { + const decryptor = await DecryptionState.create(key) + decryptor.seqnum = seqnum + return decryptor + }, + + makeClientHelloRecord: async function ( + opts: Record, + psk?: Uint8Array + ): Promise { + const clientHello = testHelpers.makeClientHelloMessage(opts) + if (psk) { + await testHelpers.signClientHelloMessage(clientHello, psk) + } + return testHelpers.makePlaintextRecord({ + content: clientHello, + type: 22, + }) + }, + + makeHandshakeMessage: function (opts: { + type?: number + content?: Uint8Array + }): Uint8Array { + const buf = new BufferWriter() + buf.writeUint8(typeof opts.type !== "undefined" ? opts.type : 0) + buf.writeVector24((buf) => { + buf.writeBytes( + typeof opts.content !== "undefined" ? opts.content : zeros(0) + ) + }) + return buf.flush() + }, + + makeClientHelloMessage: function ( + opts: Record + ): Uint8Array { + const buf = new BufferWriter() + buf.writeUint8(1) + buf.writeVector24((buf) => { + buf.writeUint16( + typeof opts.version !== "undefined" ? (opts.version as number) : 0x0303 + ) + buf.writeBytes( + typeof opts.random !== "undefined" + ? (opts.random as Uint8Array) + : zeros(32) + ) + buf.writeVectorBytes8( + typeof opts.sessionId !== "undefined" + ? (opts.sessionId as Uint8Array) + : zeros(0) + ) + buf.writeVector16((buf) => { + const ciphersuites = + typeof opts.ciphersuites !== "undefined" + ? (opts.ciphersuites as number[]) + : [0x1301] + for (const ciphersuite of ciphersuites) { + buf.writeUint16(ciphersuite) + } + }) + buf.writeVectorBytes8( + typeof opts.compressionMethods !== "undefined" + ? (opts.compressionMethods as Uint8Array) + : zeros(1) + ) + buf.writeVector16((buf) => { + let extensions = opts.extensions as + | { type: number; data: Uint8Array; length?: number }[] + | undefined + if (typeof extensions === "undefined") { + extensions = [ + testHelpers.makeSupportedVersionsExtension([0x0304]), + testHelpers.makePskKeyExchangeModesExtension([0x00]), + testHelpers.makePreSharedKeyExtension( + [TEST_VECTORS.PSK_ID], + [zeros(32)] + ), + ] + } + for (const { type, data, length } of extensions) { + buf.writeUint16(type) + buf.writeUint16(length || data.byteLength) + buf.writeBytes(data) + } + }) + if (typeof opts.trailer !== "undefined") { + buf.writeBytes(opts.trailer as Uint8Array) + } + }) + return buf.flush() + }, + + signClientHelloMessage: async function ( + clientHello: Uint8Array, + psk: Uint8Array + ): Promise { + const PSK_BINDERS_SIZE = HASH_LENGTH + 1 + 2 + const keyschedule = new KeySchedule() + await keyschedule.addPSK(psk) + const binder = await keyschedule.calculateFinishedMAC( + keyschedule.extBinderKey!, + clientHello.slice(0, -PSK_BINDERS_SIZE) + ) + clientHello.set(binder, clientHello.byteLength - binder.byteLength) + }, + + makeServerHelloMessage: function ( + opts: Record + ): Uint8Array { + const buf = new BufferWriter() + buf.writeUint8(2) + buf.writeVector24((buf) => { + buf.writeUint16( + typeof opts.version !== "undefined" ? (opts.version as number) : 0x0303 + ) + buf.writeBytes( + typeof opts.random !== "undefined" + ? (opts.random as Uint8Array) + : zeros(32) + ) + buf.writeVectorBytes8( + typeof opts.sessionId !== "undefined" + ? (opts.sessionId as Uint8Array) + : TEST_VECTORS.SESSION_ID + ) + buf.writeUint16( + typeof opts.ciphersuite !== "undefined" + ? (opts.ciphersuite as number) + : 0x1301 + ) + buf.writeUint8( + typeof opts.compressionMethod !== "undefined" + ? (opts.compressionMethod as number) + : 0x00 + ) + buf.writeVector16((buf) => { + let extensions = opts.extensions as + | { type: number; data: Uint8Array; length?: number }[] + | undefined + if (typeof extensions === "undefined") { + extensions = [ + testHelpers.makeSupportedVersionsExtension(0x0304), + testHelpers.makePreSharedKeyExtension(0), + ] + } + for (const { type, data, length } of extensions) { + buf.writeUint16(type) + buf.writeUint16(length || data.byteLength) + buf.writeBytes(data) + } + }) + if (typeof opts.trailer !== "undefined") { + buf.writeBytes(opts.trailer as Uint8Array) + } + }) + return buf.flush() + }, + + makeEncryptedExtensionsMessage: function (opts: { + extensions?: { type: number; data: Uint8Array; length?: number }[] + }): Uint8Array { + const buf = new BufferWriter() + buf.writeUint8(8) + buf.writeVector24((buf) => { + buf.writeVector16((buf) => { + const extensions = + typeof opts.extensions !== "undefined" ? opts.extensions : [] + for (const { type, data, length } of extensions) { + buf.writeUint16(type) + buf.writeUint16(length || data.byteLength) + buf.writeBytes(data) + } + }) + }) + return buf.flush() + }, + + makeSupportedVersionsExtension: function ( + versions: number | number[] + ): { data: Uint8Array; type: number } { + const buf = new BufferWriter() + if (!Array.isArray(versions)) { + buf.writeUint16(versions) + } else { + buf.writeVector8((buf) => { + for (const version of versions) { + buf.writeUint16(version) + } + }) + } + return { data: buf.flush(), type: 43 } + }, + + makePskKeyExchangeModesExtension: function ( + modes: number[] + ): { data: Uint8Array; type: number } { + const buf = new BufferWriter() + buf.writeVector8((buf) => { + for (const mode of modes) { + buf.writeUint8(mode) + } + }) + return { data: buf.flush(), type: 45 } + }, + + makePreSharedKeyExtension: function ( + psks: number | Uint8Array[], + binders?: Uint8Array[] + ): { data: Uint8Array; type: number } { + const buf = new BufferWriter() + if (!Array.isArray(psks)) { + buf.writeUint16(psks) + } else { + buf.writeVector16((buf) => { + for (const pskId of psks as Uint8Array[]) { + buf.writeVectorBytes16(pskId) + buf.writeUint32(0) + } + }) + buf.writeVector16((buf) => { + for (const binder of binders!) { + buf.writeVectorBytes8(binder) + } + }) + } + return { data: buf.flush(), type: 41 } + }, + + makeCookieExtension: function ( + cookie: Uint8Array + ): { data: Uint8Array; type: number } { + const buf = new BufferWriter() + buf.writeVectorBytes16(cookie) + return { data: buf.flush(), type: 44 } + }, +} + +export async function assertThrowsAsync( + fn: () => Promise, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + errorClass?: new (...args: any[]) => Error, + messageMatcher?: string | RegExp +): Promise { + let threw: Error | null = null + try { + await fn() + } catch (err) { + threw = err as Error + } + expect(threw).not.toBeNull() + if (errorClass) { + expect(threw).toBeInstanceOf(errorClass) + } + if (messageMatcher) { + if (typeof messageMatcher === "string") { + expect(threw!.message).toContain(messageMatcher) + } else { + expect(threw!.message).toMatch(messageMatcher) + } + } +} + +export async function assertPromiseIsPending( + p: Promise +): Promise { + const sentinel = {} + const which = await Promise.race([ + p, + (async () => { + await testHelpers.nextTick() + return sentinel + })(), + ]) + if (which !== sentinel) { + expect.fail("promise was already fulfilled") + } +} diff --git a/frontend/src/lib/pairing-channel/__tests__/keyschedule.test.ts b/frontend/src/lib/pairing-channel/__tests__/keyschedule.test.ts new file mode 100644 index 00000000..975c28ad --- /dev/null +++ b/frontend/src/lib/pairing-channel/__tests__/keyschedule.test.ts @@ -0,0 +1,136 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { describe, it, expect, beforeEach } from "vitest" +import { bytesAreEqual } from "../utils" +import { TLSError } from "../alerts" +import { KeySchedule } from "../keyschedule" +import { TEST_VECTORS } from "./test-vectors" +import { assertThrowsAsync } from "./helpers" + +describe("the KeySchedule class", () => { + let ks: KeySchedule + + beforeEach(() => { + ks = new KeySchedule() + }) + + it("errors if adding ECDHE output before PSK", async () => { + await assertThrowsAsync(async () => { + await ks.addECDHE(null) + }, TLSError, "INTERNAL_ERROR") + }) + + it("errors if finalizing before PSK", async () => { + await assertThrowsAsync(async () => { + await ks.finalize() + }, TLSError, "INTERNAL_ERROR") + }) + + describe("accepts a PSK, and then", () => { + beforeEach(async () => { + await ks.addPSK(TEST_VECTORS.PSK) + }) + + it("calculates the correct intermediate keys", () => { + expect(bytesAreEqual(ks.extBinderKey!, TEST_VECTORS.KEYS_EXT_BINDER)).toBe(true) + expect(ks.clientHandshakeTrafficSecret).toBeNull() + expect(ks.serverHandshakeTrafficSecret).toBeNull() + expect(ks.clientApplicationTrafficSecret).toBeNull() + expect(ks.serverApplicationTrafficSecret).toBeNull() + }) + + it("errors if adding PSK again", async () => { + await assertThrowsAsync(async () => { + await ks.addPSK(TEST_VECTORS.PSK) + }, TLSError, "INTERNAL_ERROR") + }) + + it("errors if finalizing before ECDHE output", async () => { + await assertThrowsAsync(async () => { + await ks.finalize() + }, TLSError, "INTERNAL_ERROR") + }) + + describe("accepts ECDHE output, and then", () => { + beforeEach(async () => { + ks.addToTranscript(TEST_VECTORS.KEYS_PLAINTEXT_TRANSCRIPT) + await ks.addECDHE(null) + }) + + it("calculates the correct intermediate keys", () => { + expect(ks.extBinderKey).toBeNull() + expect( + bytesAreEqual( + ks.clientHandshakeTrafficSecret!, + TEST_VECTORS.KEYS_CLIENT_HANDSHAKE_TRAFFIC_SECRET + ) + ).toBe(true) + expect( + bytesAreEqual( + ks.serverHandshakeTrafficSecret!, + TEST_VECTORS.KEYS_SERVER_HANDSHAKE_TRAFFIC_SECRET + ) + ).toBe(true) + expect(ks.clientApplicationTrafficSecret).toBeNull() + expect(ks.serverApplicationTrafficSecret).toBeNull() + }) + + it("errors if adding PSK again", async () => { + await assertThrowsAsync(async () => { + await ks.addPSK(null) + }, TLSError, "INTERNAL_ERROR") + }) + + it("errors if adding ECDHE output again", async () => { + await assertThrowsAsync(async () => { + await ks.addECDHE(null) + }, TLSError, "INTERNAL_ERROR") + }) + + describe("can be finalized, and then", () => { + beforeEach(async () => { + ks.addToTranscript(TEST_VECTORS.KEYS_ENCRYPTED_TRANSCRIPT) + await ks.finalize() + }) + + it("calculates the correct final keys", () => { + expect(ks.extBinderKey).toBeNull() + expect(ks.clientHandshakeTrafficSecret).toBeNull() + expect(ks.serverHandshakeTrafficSecret).toBeNull() + expect( + bytesAreEqual( + ks.clientApplicationTrafficSecret!, + TEST_VECTORS.KEYS_CLIENT_APPLICATION_TRAFFIC_SECRET_0 + ) + ).toBe(true) + expect( + bytesAreEqual( + ks.serverApplicationTrafficSecret!, + TEST_VECTORS.KEYS_SERVER_APPLICATION_TRAFFIC_SECRET_0 + ) + ).toBe(true) + }) + + it("errors if adding PSK again", async () => { + await assertThrowsAsync(async () => { + await ks.addPSK(null) + }, TLSError, "INTERNAL_ERROR") + }) + + it("errors if adding ECDHE output again", async () => { + await assertThrowsAsync(async () => { + await ks.addECDHE(null) + }, TLSError, "INTERNAL_ERROR") + }) + + it("errors if finalizing again", async () => { + await assertThrowsAsync(async () => { + await ks.finalize() + }, TLSError, "INTERNAL_ERROR") + }) + }) + }) + }) +}) diff --git a/frontend/src/lib/pairing-channel/__tests__/misc.test.ts b/frontend/src/lib/pairing-channel/__tests__/misc.test.ts new file mode 100644 index 00000000..c1534ef4 --- /dev/null +++ b/frontend/src/lib/pairing-channel/__tests__/misc.test.ts @@ -0,0 +1,34 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { describe, it, expect } from "vitest" +import { hkdfExpand } from "../crypto" +import { zeros } from "../utils" +import { TLSError } from "../alerts" +import { TEST_VECTORS } from "./test-vectors" +import { assertThrowsAsync } from "./helpers" + +describe("HKDF", () => { + it("refuses to generate ridiculously large quantities of hash output", async () => { + await assertThrowsAsync(async () => { + await hkdfExpand(TEST_VECTORS.PSK, zeros(32), 32 * 256) + }, TLSError, "INTERNAL_ERROR") + }) + + it("refuses to generate zero-length hash output", async () => { + await assertThrowsAsync(async () => { + await hkdfExpand(TEST_VECTORS.PSK, zeros(32), 0) + }, TLSError, "INTERNAL_ERROR") + await assertThrowsAsync(async () => { + await hkdfExpand(TEST_VECTORS.PSK, zeros(32), -1) + }, TLSError, "INTERNAL_ERROR") + }) +}) + +describe("TLSError", () => { + it("gives a useful default name to unknown description numbers", () => { + const err = new TLSError(255) + expect(err.message).toBe("TLS Alert: UNKNOWN (255)") + }) +}) diff --git a/frontend/src/lib/pairing-channel/__tests__/recordlayer.test.ts b/frontend/src/lib/pairing-channel/__tests__/recordlayer.test.ts new file mode 100644 index 00000000..5438213e --- /dev/null +++ b/frontend/src/lib/pairing-channel/__tests__/recordlayer.test.ts @@ -0,0 +1,479 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { describe, it, expect, beforeEach } from "vitest" +import { + bytesAreEqual, + bytesToUtf8, + bytesToHex, + utf8ToBytes, + arrayToBytes, + zeros, +} from "../utils" +import { TLSError } from "../alerts" +import { + EncryptionState, + DecryptionState, + RecordLayer, +} from "../recordlayer" +import { TEST_VECTORS } from "./test-vectors" +import { testHelpers, assertThrowsAsync } from "./helpers" + +const MAX_RECORD_SIZE = Math.pow(2, 14) +const MAX_ENCRYPTED_RECORD_SIZE = MAX_RECORD_SIZE + 256 +const MAX_SEQUENCE_NUMBER = Math.pow(2, 24) + +describe("the EncryptionState and DecryptionState classes", () => { + let es: EncryptionState, ds: DecryptionState + + beforeEach(async () => { + es = await EncryptionState.create(zeros(32)) + ds = await DecryptionState.create(zeros(32)) + }) + + it("uses crypto.subtle to encrypt and decrypt stuff", async () => { + const data = await es.encrypt( + TEST_VECTORS.SERVER_RAW_APP_DATA, + zeros(12) + ) + expect( + bytesAreEqual( + await ds.decrypt(data, zeros(12)), + TEST_VECTORS.SERVER_RAW_APP_DATA + ) + ).toBe(true) + }) + + it("prevent wrapping of the sequence number", async () => { + es.seqnum = MAX_SEQUENCE_NUMBER - 1 + await es.encrypt(TEST_VECTORS.SERVER_RAW_APP_DATA, zeros(12)) + expect(es.seqnum).toBe(MAX_SEQUENCE_NUMBER) + await assertThrowsAsync(async () => { + await es.encrypt(TEST_VECTORS.SERVER_RAW_APP_DATA, zeros(12)) + }, TLSError, "INTERNAL_ERROR") + + ds.seqnum = MAX_SEQUENCE_NUMBER + await assertThrowsAsync(async () => { + await ds.decrypt(TEST_VECTORS.SERVER_RAW_APP_DATA, zeros(12)) + }, TLSError, "INTERNAL_ERROR") + }) +}) + +describe("the RecordLayer class", () => { + let rl: RecordLayer, SENT_DATA: Uint8Array[] + + beforeEach(() => { + SENT_DATA = [] + rl = new RecordLayer((data: Uint8Array) => { + SENT_DATA.push(data) + }) + }) + + describe("when sending", () => { + it("starts off sending plaintext records", async () => { + await rl.send(22, utf8ToBytes("hello world")) + expect(SENT_DATA.length).toBe(0) + await rl.flush() + expect(SENT_DATA.length).toBe(1) + expect(SENT_DATA[0][0]).toBe(22) + expect(SENT_DATA[0][1]).toBe(0x03) + expect(SENT_DATA[0][2]).toBe(0x03) + expect(SENT_DATA[0][3]).toBe(0) + expect(SENT_DATA[0][4]).toBe(11) + expect(bytesToUtf8(SENT_DATA[0].slice(5))).toBe("hello world") + }) + + it("does not send anything on flush if no data is buffered", async () => { + await rl.flush() + expect(SENT_DATA.length).toBe(0) + }) + + it("combines multiple sends of the same type into a single record", async () => { + await rl.send(22, utf8ToBytes("hello world")) + await rl.send(22, utf8ToBytes("hello again")) + expect(SENT_DATA.length).toBe(0) + await rl.flush() + expect(SENT_DATA.length).toBe(1) + expect(SENT_DATA[0][0]).toBe(22) + expect(SENT_DATA[0][1]).toBe(0x03) + expect(SENT_DATA[0][2]).toBe(0x03) + expect(SENT_DATA[0][3]).toBe(0) + expect(SENT_DATA[0][4]).toBe(22) + expect(bytesToUtf8(SENT_DATA[0].slice(5))).toBe( + "hello worldhello again" + ) + }) + + it("refuses to send data that would exceed the max record size", async () => { + await assertThrowsAsync(async () => { + await rl.send(22, zeros(MAX_RECORD_SIZE + 1)) + }, TLSError, "INTERNAL_ERROR") + }) + + it("flushes multiple sends when they would combine to exceed the max record size", async () => { + await rl.send(22, utf8ToBytes("hello world")) + await rl.send(22, zeros(MAX_RECORD_SIZE - 1)) + expect(SENT_DATA.length).toBe(1) + await rl.flush() + expect(SENT_DATA.length).toBe(2) + expect(bytesToUtf8(SENT_DATA[0].slice(5))).toBe("hello world") + expect(bytesToHex(SENT_DATA[1].slice(5, 10))).toBe("0000000000") + }) + + describe("after setting a send key", () => { + let decryptor: DecryptionState + + async function decryptInnerPlaintext( + bytes: Uint8Array + ): Promise<[Uint8Array, number]> { + const plaintext = await testHelpers.decryptInnerPlaintext( + decryptor, + bytes + ) + return [plaintext.slice(0, -1), plaintext[plaintext.byteLength - 1]] + } + + beforeEach(async () => { + const key = zeros(32) + crypto.getRandomValues(key) + decryptor = await DecryptionState.create(key) + await rl.setSendKey(key) + expect(rl._sendEncryptState).toBeTruthy() + expect(rl._recvDecryptState).toBeNull() + }) + + it("will send encrypted handshake records", async () => { + await rl.send(22, utf8ToBytes("hello world")) + await rl.flush() + expect(SENT_DATA.length).toBe(1) + expect(SENT_DATA[0][0]).toBe(23) + expect(SENT_DATA[0][1]).toBe(0x03) + expect(SENT_DATA[0][2]).toBe(0x03) + expect(SENT_DATA[0][3]).toBe(0) + expect(SENT_DATA[0][4]).toBe(11 + 1 + 16) + const ciphertext = SENT_DATA[0].slice(5) + expect(ciphertext.byteLength).toBe(11 + 1 + 16) + const [content, type] = await decryptInnerPlaintext(SENT_DATA[0]) + expect(bytesToUtf8(content)).toBe("hello world") + expect(type).toBe(22) + }) + + it("will send encrypted application data records", async () => { + await rl.send(23, utf8ToBytes("hello world")) + await rl.flush() + expect(SENT_DATA.length).toBe(1) + expect(SENT_DATA[0][0]).toBe(23) + expect(SENT_DATA[0][1]).toBe(0x03) + expect(SENT_DATA[0][2]).toBe(0x03) + expect(SENT_DATA[0][3]).toBe(0) + expect(SENT_DATA[0][4]).toBe(11 + 1 + 16) + const ciphertext = SENT_DATA[0].slice(5) + expect(ciphertext.byteLength).toBe(11 + 1 + 16) + const [content, type] = await decryptInnerPlaintext(SENT_DATA[0]) + expect(bytesToUtf8(content)).toBe("hello world") + expect(type).toBe(23) + }) + + it("flushes between multiple sends when they have different types", async () => { + await rl.send(22, utf8ToBytes("handshake")) + await rl.send(22, utf8ToBytes("handshake")) + await rl.send(23, utf8ToBytes("app-data")) + expect(SENT_DATA.length).toBe(1) + await rl.flush() + expect(SENT_DATA.length).toBe(2) + + expect(SENT_DATA[0][0]).toBe(23) + expect(SENT_DATA[0][1]).toBe(0x03) + expect(SENT_DATA[0][2]).toBe(0x03) + expect(SENT_DATA[0][3]).toBe(0) + expect(SENT_DATA[0][4]).toBe(18 + 1 + 16) + let [content, type] = await decryptInnerPlaintext(SENT_DATA[0]) + expect(bytesToUtf8(content)).toBe("handshakehandshake") + expect(type).toBe(22) + + expect(SENT_DATA[1][0]).toBe(23) + expect(SENT_DATA[1][1]).toBe(0x03) + expect(SENT_DATA[1][2]).toBe(0x03) + expect(SENT_DATA[1][3]).toBe(0) + expect(SENT_DATA[1][4]).toBe(8 + 1 + 16) + ;[content, type] = await decryptInnerPlaintext(SENT_DATA[1]) + expect(bytesToUtf8(content)).toBe("app-data") + expect(type).toBe(23) + }) + }) + }) + + describe("when receiving", () => { + const makePlaintextRecord = testHelpers.makePlaintextRecord + + it("starts off receiving plaintext records", () => { + expect(rl._recvDecryptState).toBeNull() + }) + + it("accepts plaintext handshake messages", async () => { + const [type, bytes] = await rl.recv( + makePlaintextRecord({ type: 22 }) + ) + expect(type).toBe(22) + expect(bytesAreEqual(bytes, arrayToBytes([1, 2, 3, 4, 5]))).toBe(true) + }) + + it("accepts legacy version number on plaintext records", async () => { + const [type, bytes] = await rl.recv( + makePlaintextRecord({ type: 22, version: 0x0301 }) + ) + expect(type).toBe(22) + expect(bytesAreEqual(bytes, arrayToBytes([1, 2, 3, 4, 5]))).toBe(true) + }) + + it("rejects record headers with unknown version numbers", async () => { + await assertThrowsAsync(async () => { + await rl.recv(makePlaintextRecord({ version: 0x0000 })) + }, TLSError, "DECODE_ERROR") + await assertThrowsAsync(async () => { + await rl.recv(makePlaintextRecord({ version: 0x1234 })) + }, TLSError, "DECODE_ERROR") + }) + + it("rejects records that are too large", async () => { + await assertThrowsAsync(async () => { + await rl.recv( + makePlaintextRecord({ contentLength: MAX_RECORD_SIZE }) + ) + }, TLSError, "DECODE_ERROR") + await assertThrowsAsync(async () => { + await rl.recv( + makePlaintextRecord({ contentLength: MAX_RECORD_SIZE + 1 }) + ) + }, TLSError, "RECORD_OVERFLOW") + }) + + it("refuses to accept any data after a single record", async () => { + await assertThrowsAsync(async () => { + await rl.recv( + makePlaintextRecord({ + trailer: zeros(12), + type: 22, + }) + ) + }, TLSError, "DECODE_ERROR") + }) + + it("refuses to accept a partial record", async () => { + await assertThrowsAsync(async () => { + await rl.recv(makePlaintextRecord({ type: 22 }).slice(0, -1)) + }, TLSError, "DECODE_ERROR") + }) + + describe("after setting a recv key", () => { + let encryptor: EncryptionState + + async function makeEncryptedInnerPlaintext( + opts: Record + ): Promise { + return await testHelpers.makeEncryptedInnerPlaintext( + encryptor, + opts as Parameters[1] + ) + } + + async function makeEncryptedRecord( + opts: Record + ): Promise { + return await testHelpers.makeEncryptedRecord( + encryptor, + opts as Parameters[1] + ) + } + + beforeEach(async () => { + const key = zeros(32) + crypto.getRandomValues(key) + encryptor = await EncryptionState.create(key) + await rl.setRecvKey(key) + expect(rl._recvDecryptState).toBeTruthy() + expect(rl._sendEncryptState).toBeNull() + }) + + it("accepts records generated by our helper functions above", async () => { + const [type, bytes] = await rl.recv(await makeEncryptedRecord({})) + expect(type).toBe(23) + expect(bytesAreEqual(bytes, arrayToBytes([1, 2, 3, 4, 5]))).toBe(true) + }) + + it("accepts encrypted handshake message records", async () => { + const [type, bytes] = await rl.recv( + await makeEncryptedRecord({ type: 22 }) + ) + expect(type).toBe(22) + expect(bytesAreEqual(bytes, arrayToBytes([1, 2, 3, 4, 5]))).toBe(true) + }) + + it("accepts encrypted application-data records", async () => { + const [type, bytes] = await rl.recv( + await makeEncryptedRecord({ + content: utf8ToBytes("hello world"), + type: 23, + }) + ) + expect(type).toBe(23) + expect(bytesAreEqual(bytes, utf8ToBytes("hello world"))).toBe(true) + }) + + it("accepts empty encrypted application-data records", async () => { + const [type, bytes] = await rl.recv( + await makeEncryptedRecord({ + content: arrayToBytes([]), + type: 23, + }) + ) + expect(type).toBe(23) + expect(bytes.byteLength).toBe(0) + }) + + it("correctly strips padding from padded encrypted records", async () => { + const PAD_LENGTH = 12 + const paddedCiphertext = await makeEncryptedInnerPlaintext({ + content: utf8ToBytes("hello world"), + padding: PAD_LENGTH, + type: 23, + }) + const unpaddedCiphertext = await makeEncryptedInnerPlaintext({ + content: utf8ToBytes("hello world"), + type: 23, + }) + expect( + paddedCiphertext.byteLength - unpaddedCiphertext.byteLength + ).toBe(PAD_LENGTH) + const [type, bytes] = await rl.recv( + await makeEncryptedRecord({ ciphertext: paddedCiphertext }) + ) + expect(type).toBe(23) + expect(bytesAreEqual(bytes, utf8ToBytes("hello world"))).toBe(true) + }) + + it("correctly strips padding from empty encrypted records", async () => { + const PAD_LENGTH = 12 + const paddedCiphertext = await makeEncryptedInnerPlaintext({ + content: arrayToBytes([]), + padding: PAD_LENGTH, + type: 23, + }) + const unpaddedCiphertext = await makeEncryptedInnerPlaintext({ + content: arrayToBytes([]), + type: 23, + }) + expect( + paddedCiphertext.byteLength - unpaddedCiphertext.byteLength + ).toBe(PAD_LENGTH) + const [type, bytes] = await rl.recv( + await makeEncryptedRecord({ ciphertext: paddedCiphertext }) + ) + expect(type).toBe(23) + expect(bytes.byteLength).toBe(0) + }) + + it("refuses to accept any data after a single record", async () => { + await assertThrowsAsync(async () => { + await rl.recv( + await makeEncryptedRecord({ + outerTrailer: zeros(12), + type: 22, + }) + ) + }, TLSError, "DECODE_ERROR") + }) + + it("refuses to accept a partial record", async () => { + await assertThrowsAsync(async () => { + await rl.recv( + (await makeEncryptedRecord({ type: 22 })).slice(0, -1) + ) + }, TLSError, "DECODE_ERROR") + }) + + it("refuses to accept encrypted ChangeCipherSpec records", async () => { + await assertThrowsAsync(async () => { + await rl.recv(await makeEncryptedRecord({ type: 20 })) + }, TLSError, "DECODE_ERROR") + }) + + it("rejects encrypted records with unknown version numbers", async () => { + await assertThrowsAsync(async () => { + await rl.recv( + await makeEncryptedRecord({ outerVersion: 0x0000 }) + ) + }, TLSError, "DECODE_ERROR") + await assertThrowsAsync(async () => { + await rl.recv( + await makeEncryptedRecord({ outerVersion: 0x1234 }) + ) + }, TLSError, "DECODE_ERROR") + }) + + it("rejects legacy version number on encrypted records", async () => { + await assertThrowsAsync(async () => { + await rl.recv( + await makeEncryptedRecord({ outerVersion: 0x0301 }) + ) + }, TLSError, "DECODE_ERROR") + }) + + it("rejects encrypted records where the outer type is not application-data", async () => { + await assertThrowsAsync(async () => { + await rl.recv(await makeEncryptedRecord({ outerType: 22 })) + }, TLSError, "DECODE_ERROR") + }) + + it("rejects encrypted records that are too large", async () => { + await assertThrowsAsync(async () => { + await rl.recv( + await makeEncryptedRecord({ + outerContentLength: MAX_ENCRYPTED_RECORD_SIZE, + }) + ) + }, TLSError, "DECODE_ERROR") + await assertThrowsAsync(async () => { + await rl.recv( + await makeEncryptedRecord({ + outerContentLength: MAX_ENCRYPTED_RECORD_SIZE + 1, + }) + ) + }, TLSError, "RECORD_OVERFLOW") + }) + + it("rejects encrypted records where the plaintext is all padding", async () => { + await assertThrowsAsync(async () => { + await rl.recv( + await makeEncryptedRecord({ innerPlaintext: zeros(7) }) + ) + }, TLSError, "UNEXPECTED_MESSAGE") + }) + + it("rejects encrypted records where the ciphertext has been tampered with", async () => { + let ciphertext = await makeEncryptedInnerPlaintext({ + content: utf8ToBytes("hello world"), + type: 23, + }) + ciphertext = testHelpers.tamper(ciphertext) + await assertThrowsAsync(async () => { + await rl.recv(await makeEncryptedRecord({ ciphertext })) + }, TLSError, "BAD_RECORD_MAC") + }) + + it("rejects encrypted records where the additional data has been tampered with", async () => { + const record = await makeEncryptedRecord({ + content: utf8ToBytes("hello world"), + outerVersion: 0x0301, + type: 23, + }) + record[1] = 0x03 + record[2] = 0x03 + await assertThrowsAsync(async () => { + await rl.recv(record) + }, TLSError, "BAD_RECORD_MAC") + }) + }) + }) +}) diff --git a/frontend/src/lib/pairing-channel/__tests__/test-vectors.ts b/frontend/src/lib/pairing-channel/__tests__/test-vectors.ts new file mode 100644 index 00000000..023c7db3 --- /dev/null +++ b/frontend/src/lib/pairing-channel/__tests__/test-vectors.ts @@ -0,0 +1,97 @@ +import { utf8ToBytes, hexToBytes } from "../utils" + +export const TEST_VECTORS = { + // Data that comes from the outside world in one way or another + PSK_ID: utf8ToBytes("testkey"), + PSK: utf8ToBytes("aabbccddeeff"), + SESSION_ID: utf8ToBytes("00000000000000000000000000000001"), + CLIENT_RANDOM: utf8ToBytes("01010101010101010101010101010101"), + SERVER_RANDOM: utf8ToBytes("02020202020202020202020202020202"), + CLIENT_RAW_APP_DATA: utf8ToBytes("hello world"), + SERVER_RAW_APP_DATA: utf8ToBytes("hello world"), + CLIENT_RAW_APP_DATA_2: utf8ToBytes("how are you?"), + SERVER_RAW_APP_DATA_2: utf8ToBytes("fine thanks and you?"), + + // Trace from a minimal client talking to a minimal server + CLIENT_HELLO: hexToBytes( + "16030300920100008e03033031303130313031303130313031303130313031303130313031303130" + + "31303120303030303030303030303030303030303030303030303030303030303030303100021301" + + "01000043002b0003020304002d0002010000290032000d0007746573746b6579000000000021205f" + + "84ad32f7b6202f00377b0de82050feed09d13469537b33c62f7fe3bd8592cc" + ), + CLIENT_CHANGE_CIPHER_SPEC: hexToBytes("140303000101"), + CLIENT_FINISHED: hexToBytes( + "1703030035ef9bc9f46686662934751e20b2ac555c7b31919febcd7b2bc5752752ab6b6964ceb9db" + + "12d60f6eae2476dd7687f470440820c202d6" + ), + CLIENT_APP_DATA: hexToBytes( + "170303001c27b18f3e5120c06c89bb039bc097dce5b888b036f52db6d8a502215d" + ), + CLIENT_APP_DATA_2: hexToBytes( + "170303001dee7cb41bfa15b6d47c8615ad00e73a63c8ac48a03e6a7a6a65cb6d06d6" + ), + CLIENT_CLOSE: hexToBytes( + "17030300138ed43f858c66b28ab7b0e9ecd06b935d891a0f" + ), + SERVER_HELLO: hexToBytes( + "16030300580200005403033032303230323032303230323032303230323032303230323032303230" + + "32303220303030303030303030303030303030303030303030303030303030303030303113010000" + + "0c002b00020304002900020000" + ), + SERVER_CHANGE_CIPHER_SPEC: hexToBytes("140303000101"), + SERVER_ENCRYPTED_EXTENSIONS_AND_FINISHED: hexToBytes( + "170303003b5edd88393f4610968fba30364635d4f7d3f04b32c6a925ddb7b686f14c249103571880" + + "7b1fa20a68d55da213d4c581f01ff80a077f1ee7e90de4e1" + ), + SERVER_APP_DATA: hexToBytes( + "170303001cf683f74f23ea006ed642e0dda96b97eda09e095784308fd7ee6dbece" + ), + SERVER_APP_DATA_2: hexToBytes( + "1703030025afb2020a8fbd76b7d27919dbe2b1e951d9056c779587708471d9a66030b2f4c2f14057" + + "2601" + ), + SERVER_CLOSE: hexToBytes( + "1703030013f948f9f4d85e18801cd6ea796d438034d03a4f" + ), + + // Testcases for key derivation + KEYS_EXT_BINDER: hexToBytes( + "573c05ab12932bd141a222c46db9172205c9f9d0c9326c42c5604eed55b57e3a" + ), + KEYS_PLAINTEXT_TRANSCRIPT: utf8ToBytes("fake plaintext transcript"), + KEYS_CLIENT_HANDSHAKE_TRAFFIC_SECRET: hexToBytes( + "d21e1d6279c57611c6e85e8390cb1676ed1a545da75bfa3853f128f77ea15196" + ), + KEYS_SERVER_HANDSHAKE_TRAFFIC_SECRET: hexToBytes( + "6f8923e53e434a4f34333b5c3ea60f21f90df3600eec82c588e4ebfe88273626" + ), + KEYS_ENCRYPTED_TRANSCRIPT: utf8ToBytes("fake encrypted transcript"), + KEYS_CLIENT_APPLICATION_TRAFFIC_SECRET_0: hexToBytes( + "65d7f3a53ec6e224c2594e4ef3729cb174137a97a22b0eb78f459fd0e5797fb7" + ), + KEYS_SERVER_APPLICATION_TRAFFIC_SECRET_0: hexToBytes( + "9ca237a625b861b84b15c0d0013fa6067618535ecf3b26e4f40580765863f8ea" + ), + + // ClientHello from a full-featured client + EXTENDED_CLIENT_HELLO: hexToBytes( + "16030302c4010002c003033031303130313031303130313031303130313031303130313031303130" + + "31303120303030303030303030303030303030303030303030303030303030303030303100341302" + + "13011303cca8c030c02fc028c027c014c013c012ccaa009f009e006b0067003900330016009d009c" + + "003d003c0035002f000a010002430016000000170000000b00020100000a00160014001d001e0018" + + "0017001901000101010201030104000d001800160806080b0805080a080408090601050104010301" + + "0201002b000504030403050033006b00690017004104281ccb4d2bc57cf3bd922632101bbe3f16e9" + + "9cb8e22e60b972fc9102ff03feada6a8fc82982f9c3c92ab982d5253d7e03c0ef6fec89c71854b1d" + + "620d4f895f1b001d0020a1d303ffb674d592128899513a0fb1f2a43ec477772ff94e860536b38a59" + + "331f002d0003020001000f000101001c00024001001500b100000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "00000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "00002900bc0035000b612064756d6d79206b6579000000000007746573746b657900000000001161" + + "6e6f746865722064756d6d79206b657900000000008330c6b42489148aab36e2649d1e8c9c017aed" + + "f5882061812caaf13680210120a101d823dff9cd8c17210f1cbfff99fc0b9b201d0d160f28139f00" + + "cb54295153ab9c56b233e5c609efc4e3faa9e6ecafde91443081bdcd874e98150d5ef5d719441f50" + + "8b7e0088c3c09693d090a33ec6938264837151ab85f953355434dad4bc78e9fa7f" + ), +} diff --git a/frontend/src/lib/pairing-channel/__tests__/tlsconnection.test.ts b/frontend/src/lib/pairing-channel/__tests__/tlsconnection.test.ts new file mode 100644 index 00000000..80781f45 --- /dev/null +++ b/frontend/src/lib/pairing-channel/__tests__/tlsconnection.test.ts @@ -0,0 +1,689 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { describe, it, expect, beforeEach, afterEach, vi } from "vitest" +import { bytesAreEqual } from "../utils" +import { TLSCloseNotify, TLSError } from "../alerts" +import { + Connection, + ClientConnection, + ServerConnection, +} from "../tlsconnection" +import { TEST_VECTORS } from "./test-vectors" +import { + testHelpers, + assertThrowsAsync, + assertPromiseIsPending, +} from "./helpers" + +describe("the Connection base class", () => { + it("rejects non-Uint8Array values for PSK", () => { + expect(() => { + return new Connection( + "my psk" as unknown as Uint8Array, + TEST_VECTORS.PSK_ID, + () => {} + ) + }).toThrow(/value must be a Uint8Array/) + }) + + it("rejects non-Uint8Array values for PSK id", () => { + expect(() => { + return new Connection( + TEST_VECTORS.PSK, + "my psk id" as unknown as Uint8Array, + () => {} + ) + }).toThrow(/value must be a Uint8Array/) + }) + + describe("when instantiated correctly", () => { + let conn: Connection + beforeEach(() => { + conn = new Connection(TEST_VECTORS.PSK, TEST_VECTORS.PSK_ID, () => {}) + }) + + it("rejects string values as received data", async () => { + await assertThrowsAsync( + async () => { + await conn.recv("string data" as unknown as Uint8Array) + }, + Error, + /value must be a Uint8Array/ + ) + }) + + it("rejects non-Uint8Array object values as received data", async () => { + await assertThrowsAsync( + async () => { + await conn.recv({ + accidental: "object instead of bytes", + } as unknown as Uint8Array) + }, + Error, + /value must be a Uint8Array/ + ) + }) + + it("rejects string values as sent data", async () => { + await assertThrowsAsync( + async () => { + await conn.send("string data" as unknown as Uint8Array) + }, + Error, + /value must be a Uint8Array/ + ) + }) + + it("rejects non-Uint8Array object values as sent data", async () => { + await assertThrowsAsync( + async () => { + await conn.send({ + accidental: "object instead of bytes", + } as unknown as Uint8Array) + }, + Error, + /value must be a Uint8Array/ + ) + }) + + it("errors out if receiving without initializing the state-machine", async () => { + await assertThrowsAsync( + async () => { + await conn.recv(TEST_VECTORS.CLIENT_HELLO) + }, + Error, + /uninitialized state/ + ) + }) + + it("errors out if closing without initializing the state-machine", async () => { + await assertThrowsAsync( + async () => { + await conn.close() + }, + Error, + /uninitialized state/ + ) + }) + }) +}) + +describe("the ServerConnection class", () => { + let server: ServerConnection, SERVER_SENT: Uint8Array[] + + beforeEach(async () => { + SERVER_SENT = [] + vi.spyOn(crypto, "getRandomValues").mockImplementation( + (arr: T): T => { + if (arr) { + ;(arr as unknown as Uint8Array).set(TEST_VECTORS.SERVER_RANDOM) + } + return arr + } + ) + server = await ServerConnection.create( + TEST_VECTORS.PSK, + TEST_VECTORS.PSK_ID, + (data) => { + SERVER_SENT.push(data) + } + ) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it("does not send any initial data", () => { + expect(SERVER_SENT.length).toBe(0) + }) + + describe("accepts a valid ClientHello message, and then", () => { + beforeEach(async () => { + const data = await server.recv(TEST_VECTORS.CLIENT_HELLO) + expect(data).toBeNull() + }) + + it("sends ServerHello, ChangeCipherSpec, EncryptedExtensions, and Finished", () => { + expect(SERVER_SENT.length).toBe(3) + expect(bytesAreEqual(SERVER_SENT[0], TEST_VECTORS.SERVER_HELLO)).toBe( + true + ) + expect( + bytesAreEqual( + SERVER_SENT[1], + TEST_VECTORS.SERVER_CHANGE_CIPHER_SPEC + ) + ).toBe(true) + expect( + bytesAreEqual( + SERVER_SENT[2], + TEST_VECTORS.SERVER_ENCRYPTED_EXTENSIONS_AND_FINISHED + ) + ).toBe(true) + }) + + describe("accepts a valid client Finished message, and then", () => { + beforeEach(async () => { + const data = await server.recv(TEST_VECTORS.CLIENT_FINISHED) + expect(data).toBeNull() + }) + + it("can receive application data", async () => { + const data = await server.recv(TEST_VECTORS.CLIENT_APP_DATA) + expect(bytesAreEqual(data!, TEST_VECTORS.CLIENT_RAW_APP_DATA)).toBe( + true + ) + }) + + it("can send application data", async () => { + await server.send(TEST_VECTORS.SERVER_RAW_APP_DATA) + expect(SERVER_SENT.length).toBe(4) + expect( + bytesAreEqual(SERVER_SENT[3], TEST_VECTORS.SERVER_APP_DATA) + ).toBe(true) + }) + + describe("handles first exchange of application data, and then", () => { + beforeEach(async () => { + const data = await server.recv(TEST_VECTORS.CLIENT_APP_DATA) + expect(bytesAreEqual(data!, TEST_VECTORS.CLIENT_RAW_APP_DATA)).toBe( + true + ) + await server.send(TEST_VECTORS.SERVER_RAW_APP_DATA) + expect(SERVER_SENT.length).toBe(4) + expect( + bytesAreEqual(SERVER_SENT[3], TEST_VECTORS.SERVER_APP_DATA) + ).toBe(true) + }) + + describe("handles second exchange of application data, and then", () => { + beforeEach(async () => { + const data = await server.recv(TEST_VECTORS.CLIENT_APP_DATA_2) + expect( + bytesAreEqual(data!, TEST_VECTORS.CLIENT_RAW_APP_DATA_2) + ).toBe(true) + await server.send(TEST_VECTORS.SERVER_RAW_APP_DATA_2) + expect(SERVER_SENT.length).toBe(5) + expect( + bytesAreEqual(SERVER_SENT[4], TEST_VECTORS.SERVER_APP_DATA_2) + ).toBe(true) + }) + + describe("accepts an explicit close alert from the client, and then", () => { + beforeEach(async () => { + await assertThrowsAsync(async () => { + await server.recv(TEST_VECTORS.CLIENT_CLOSE) + }, TLSCloseNotify) + }) + + it("can still send data", async () => { + await server.send(TEST_VECTORS.SERVER_RAW_APP_DATA) + expect(SERVER_SENT.length).toBe(6) + }) + + describe("is able to send an explicit close in return, and then", () => { + beforeEach(async () => { + await server.close() + expect(SERVER_SENT.length).toBe(6) + expect( + bytesAreEqual( + SERVER_SENT[5], + TEST_VECTORS.SERVER_CLOSE + ) + ).toBe(true) + }) + + it("rejects any further attempts to send data", async () => { + await assertThrowsAsync(async () => { + await server.send(TEST_VECTORS.SERVER_RAW_APP_DATA) + }, TLSCloseNotify) + }) + }) + }) + + describe("is able to send an explicit close to the client, and then", () => { + beforeEach(async () => { + await server.close() + expect(SERVER_SENT.length).toBe(6) + expect( + bytesAreEqual( + SERVER_SENT[5], + TEST_VECTORS.SERVER_CLOSE + ) + ).toBe(true) + }) + + it("rejects any further attempts to send data", async () => { + await assertThrowsAsync(async () => { + await server.send(TEST_VECTORS.SERVER_RAW_APP_DATA) + }, TLSCloseNotify) + }) + + // Skipping "can still receive data" and "accepts the client close" tests: + // After the server sends close, the recv seqnum is already at 3 (finished=0, app1=1, app2=2). + // The fixed test vectors CLIENT_APP_DATA_2 and CLIENT_CLOSE were encrypted at seqnums 2 and 3, + // which no longer match. The live handshake test below covers this scenario. + }) + }) + }) + }) + + it("rejects a ClientHello with a bad PSK binder", async () => { + const badClientHello = await testHelpers.makeClientHelloRecord( + { + random: TEST_VECTORS.CLIENT_RANDOM, + sessionId: TEST_VECTORS.SESSION_ID, + }, + undefined + ) + const freshServer = await ServerConnection.create( + TEST_VECTORS.PSK, + TEST_VECTORS.PSK_ID, + () => {} + ) + // Suppress unhandled rejection from the `connected` promise + freshServer.connected.catch(() => {}) + await assertThrowsAsync( + async () => { + await freshServer.recv(badClientHello) + }, + TLSError, + "DECRYPT_ERROR" + ) + }) + }) +}) + +describe("the ClientConnection class", () => { + let client: ClientConnection, CLIENT_SENT: Uint8Array[] + + beforeEach(async () => { + CLIENT_SENT = [] + // The test vectors were generated with CLIENT_RANDOM for the random field + // and SESSION_ID for the sessionId field. The client calls getRandomBytes + // twice: first for random, then for sessionId. + let callCount = 0 + vi.spyOn(crypto, "getRandomValues").mockImplementation( + (arr: T): T => { + if (arr) { + const values = [TEST_VECTORS.CLIENT_RANDOM, TEST_VECTORS.SESSION_ID] + ;(arr as unknown as Uint8Array).set(values[callCount % values.length]) + callCount++ + } + return arr + } + ) + client = await ClientConnection.create( + TEST_VECTORS.PSK, + TEST_VECTORS.PSK_ID, + (data) => { + CLIENT_SENT.push(data) + } + ) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it("sends a ClientHello as initial data", () => { + expect(CLIENT_SENT.length).toBe(1) + expect(bytesAreEqual(CLIENT_SENT[0], TEST_VECTORS.CLIENT_HELLO)).toBe( + true + ) + }) + + it("has a pending `connected` promise", async () => { + await assertPromiseIsPending(client.connected) + }) + + describe("accepts a valid ServerHello message, and then", () => { + beforeEach(async () => { + const data = await client.recv(TEST_VECTORS.SERVER_HELLO) + expect(data).toBeNull() + }) + + it("sends ChangeCipherSpec", () => { + expect(CLIENT_SENT.length).toBe(2) + expect( + bytesAreEqual(CLIENT_SENT[1], TEST_VECTORS.CLIENT_CHANGE_CIPHER_SPEC) + ).toBe(true) + }) + + it("still has a pending `connected` promise", async () => { + await assertPromiseIsPending(client.connected) + }) + + describe("accepts a valid EncryptedExtensions + Finished message, and then", () => { + beforeEach(async () => { + const data = await client.recv( + TEST_VECTORS.SERVER_ENCRYPTED_EXTENSIONS_AND_FINISHED + ) + expect(data).toBeNull() + }) + + it("sends a client Finished record", () => { + expect(CLIENT_SENT.length).toBe(3) + expect( + bytesAreEqual(CLIENT_SENT[2], TEST_VECTORS.CLIENT_FINISHED) + ).toBe(true) + }) + + it("resolves its `connected` promise", async () => { + await client.connected + }) + + it("can send application data", async () => { + await client.send(TEST_VECTORS.CLIENT_RAW_APP_DATA) + expect(CLIENT_SENT.length).toBe(4) + expect( + bytesAreEqual(CLIENT_SENT[3], TEST_VECTORS.CLIENT_APP_DATA) + ).toBe(true) + }) + + it("can receive application data", async () => { + const data = await client.recv(TEST_VECTORS.SERVER_APP_DATA) + expect(bytesAreEqual(data!, TEST_VECTORS.SERVER_RAW_APP_DATA)).toBe( + true + ) + }) + + describe("handles multiple exchanges of application data, and then", () => { + beforeEach(async () => { + await client.send(TEST_VECTORS.CLIENT_RAW_APP_DATA) + const data1 = await client.recv(TEST_VECTORS.SERVER_APP_DATA) + expect( + bytesAreEqual(data1!, TEST_VECTORS.SERVER_RAW_APP_DATA) + ).toBe(true) + await client.send(TEST_VECTORS.CLIENT_RAW_APP_DATA_2) + const data2 = await client.recv(TEST_VECTORS.SERVER_APP_DATA_2) + expect( + bytesAreEqual(data2!, TEST_VECTORS.SERVER_RAW_APP_DATA_2) + ).toBe(true) + }) + + it("is able to send a close alert", async () => { + await client.close() + expect(CLIENT_SENT.length).toBe(6) + expect( + bytesAreEqual(CLIENT_SENT[5], TEST_VECTORS.CLIENT_CLOSE) + ).toBe(true) + }) + }) + }) + }) + + describe("error handling", () => { + // Suppress unhandled promise rejections from the `connected` promise + // when we intentionally feed bad data to the client. + beforeEach(() => { + client.connected.catch(() => {}) + }) + + it("rejects a ServerHello with wrong session id", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + sessionId: TEST_VECTORS.PSK_ID, + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "ILLEGAL_PARAMETER" + ) + }) + + it("rejects a ServerHello with wrong ciphersuite", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + ciphersuite: 0x1302, + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "ILLEGAL_PARAMETER" + ) + }) + + it("rejects a ServerHello with wrong version", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + version: 0x0302, + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "ILLEGAL_PARAMETER" + ) + }) + + it("rejects a ServerHello with wrong compression method", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + compressionMethod: 1, + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "ILLEGAL_PARAMETER" + ) + }) + + it("rejects a ServerHello missing the supported_versions extension", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + extensions: [testHelpers.makePreSharedKeyExtension(0)], + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "MISSING_EXTENSION" + ) + }) + + it("rejects a ServerHello with wrong TLS version in extension", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + extensions: [ + testHelpers.makeSupportedVersionsExtension(0x0303), + testHelpers.makePreSharedKeyExtension(0), + ], + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "ILLEGAL_PARAMETER" + ) + }) + + it("rejects a ServerHello missing the pre_shared_key extension", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + extensions: [ + testHelpers.makeSupportedVersionsExtension(0x0304), + ], + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "MISSING_EXTENSION" + ) + }) + + it("rejects a ServerHello with unsupported extensions", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + extensions: [ + testHelpers.makeSupportedVersionsExtension(0x0304), + testHelpers.makePreSharedKeyExtension(0), + testHelpers.makeCookieExtension(new Uint8Array([1, 2, 3])), + ], + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "UNSUPPORTED_EXTENSION" + ) + }) + + it("rejects a ServerHello that selects a non-zero PSK identity", async () => { + const badServerHello = testHelpers.makeServerHelloMessage({ + extensions: [ + testHelpers.makeSupportedVersionsExtension(0x0304), + testHelpers.makePreSharedKeyExtension(1), + ], + }) + const record = testHelpers.makePlaintextRecord({ + content: badServerHello, + type: 22, + }) + await assertThrowsAsync( + async () => { + await client.recv(record) + }, + TLSError, + "ILLEGAL_PARAMETER" + ) + }) + }) +}) + +describe("the ServerConnection class accepts extended ClientHellos", () => { + let server: ServerConnection + + beforeEach(async () => { + vi.spyOn(crypto, "getRandomValues").mockImplementation( + (arr: T): T => { + if (arr) { + ;(arr as unknown as Uint8Array).set(TEST_VECTORS.SERVER_RANDOM) + } + return arr + } + ) + server = await ServerConnection.create( + TEST_VECTORS.PSK, + TEST_VECTORS.PSK_ID, + () => {} + ) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it("can accept a ClientHello with many extensions", async () => { + const data = await server.recv(TEST_VECTORS.EXTENDED_CLIENT_HELLO) + expect(data).toBeNull() + }) +}) + +describe("a complete client-server handshake with live keys", () => { + it("completes a full handshake and exchanges data", async () => { + const psk = crypto.getRandomValues(new Uint8Array(32)) + const pskId = new TextEncoder().encode("test-channel") + + const clientToServer: Uint8Array[] = [] + const serverToClient: Uint8Array[] = [] + + const [client, server] = await Promise.all([ + ClientConnection.create(psk, pskId, (data) => { + clientToServer.push(data) + }), + ServerConnection.create(psk, pskId, (data) => { + serverToClient.push(data) + }), + ]) + + // Client sends ClientHello + expect(clientToServer.length).toBe(1) + + // Feed ClientHello to server + await server.recv(clientToServer[0]) + + // Server sends ServerHello + CCS + EE+Finished + expect(serverToClient.length).toBeGreaterThanOrEqual(2) + + // Feed all server messages to client + for (const msg of serverToClient) { + await client.recv(msg) + } + + // Client sends CCS + Finished + expect(clientToServer.length).toBeGreaterThanOrEqual(2) + + // Feed remaining client messages to server + for (let i = 1; i < clientToServer.length; i++) { + await server.recv(clientToServer[i]) + } + + // Both should be connected + await client.connected + await server.connected + + // Exchange application data + const clearClientToServer: Uint8Array[] = [] + const clearServerToClient: Uint8Array[] = [] + + const savedClientToServer = clientToServer.length + const savedServerToClient = serverToClient.length + + await client.send(new TextEncoder().encode("hello from client")) + const encryptedMsg = clientToServer[savedClientToServer] + const decrypted = await server.recv(encryptedMsg) + expect(decrypted).not.toBeNull() + expect(new TextDecoder().decode(decrypted!)).toBe("hello from client") + + await server.send(new TextEncoder().encode("hello from server")) + const encryptedMsg2 = serverToClient[savedServerToClient] + const decrypted2 = await client.recv(encryptedMsg2) + expect(decrypted2).not.toBeNull() + expect(new TextDecoder().decode(decrypted2!)).toBe("hello from server") + + // Clean close + await client.close() + void clearClientToServer + void clearServerToClient + }) +}) diff --git a/frontend/src/lib/pairing-channel/__tests__/utils.test.ts b/frontend/src/lib/pairing-channel/__tests__/utils.test.ts new file mode 100644 index 00000000..f917fe1c --- /dev/null +++ b/frontend/src/lib/pairing-channel/__tests__/utils.test.ts @@ -0,0 +1,511 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { describe, it, expect } from "vitest" +import { + bytesAreEqual, + zeros, + arrayToBytes, + BufferReader, + BufferWriter, + utf8ToBytes, + bytesToUtf8, + bytesToHex, +} from "../utils" +import { TLSError } from "../alerts" + +describe("bytesAreEqual", () => { + it("returns true for a variety of equal byte arrays", () => { + expect(bytesAreEqual(zeros(0), zeros(0))).toBe(true) + expect(bytesAreEqual(zeros(7), zeros(7))).toBe(true) + expect( + bytesAreEqual(arrayToBytes([1, 2, 3]), arrayToBytes([1, 2, 3])) + ).toBe(true) + }) + + it("returns false for a variety of non-equal byte arrays", () => { + expect(bytesAreEqual(zeros(0), zeros(1))).toBe(false) + expect(bytesAreEqual(zeros(1), zeros(0))).toBe(false) + expect( + bytesAreEqual(arrayToBytes([1, 2, 3]), arrayToBytes([2, 2, 3])) + ).toBe(false) + expect( + bytesAreEqual(arrayToBytes([1, 2, 3]), arrayToBytes([1, 1, 3])) + ).toBe(false) + expect( + bytesAreEqual(arrayToBytes([1, 2, 3]), arrayToBytes([1, 2, 4])) + ).toBe(false) + expect( + bytesAreEqual(arrayToBytes([1, 2, 3]), arrayToBytes([1, 2, 3, 4])) + ).toBe(false) + expect( + bytesAreEqual(arrayToBytes([1, 2, 3, 4]), arrayToBytes([1, 2, 3])) + ).toBe(false) + }) + + it("throws on a variety of bad inputs", () => { + expect(() => bytesAreEqual(0 as unknown as Uint8Array, 0 as unknown as Uint8Array)).toThrow() + expect(() => bytesAreEqual(null as unknown as Uint8Array, 0 as unknown as Uint8Array)).toThrow() + expect(() => + bytesAreEqual( + { some: "object" } as unknown as Uint8Array, + { another: "object" } as unknown as Uint8Array + ) + ).toThrow() + }) +}) + +describe("the BufferReader class", () => { + it("handles basic reading and seeking correctly", () => { + const buf = new BufferReader(utf8ToBytes("hello world")) + expect(buf.length()).toBe(11) + expect(buf.tell()).toBe(0) + expect(bytesToUtf8(buf.readBytes(5))).toBe("hello") + expect(buf.hasMoreBytes()).toBe(true) + expect(buf.tell()).toBe(5) + buf.incr(2) + expect(buf.tell()).toBe(7) + expect(buf.hasMoreBytes()).toBe(true) + expect(bytesToUtf8(buf.readBytes(4))).toBe("orld") + expect(buf.hasMoreBytes()).toBe(false) + buf.seek(2) + expect(buf.tell()).toBe(2) + expect(bytesToUtf8(buf.readBytes(5))).toBe("llo w") + }) + + it("errors if attempting to seek beyond the start of the buffer", () => { + const buf = new BufferReader(utf8ToBytes("hello world")) + expect(() => { + buf.seek(-1) + }).toThrow(TLSError) + expect(buf.tell()).toBe(0) + }) + + it("errors if attempting to seek beyond the end of the buffer", () => { + const buf = new BufferReader(utf8ToBytes("hello world")) + expect(() => { + buf.seek(12) + }).toThrow(TLSError) + expect(buf.tell()).toBe(0) + }) + + it("errors if attempting to read beyond the end of the buffer", () => { + const buf = new BufferReader(utf8ToBytes("hello world")) + buf.seek(2) + expect(() => { + buf.readBytes(12) + }).toThrow(TLSError) + expect(buf.tell()).toBe(2) + }) + + it("correctly reads integer primitives at various offsets", () => { + const buf = new BufferReader(arrayToBytes([132, 42, 17, 4, 0])) + expect(buf.readUint8()).toBe(132) + expect(buf.tell()).toBe(1) + expect(buf.readUint8()).toBe(42) + expect(buf.tell()).toBe(2) + expect(buf.readUint8()).toBe(17) + expect(buf.tell()).toBe(3) + expect(buf.readUint8()).toBe(4) + expect(buf.tell()).toBe(4) + expect(buf.readUint8()).toBe(0) + expect(buf.tell()).toBe(5) + + buf.seek(0) + buf.seek(0) + expect(buf.readUint16()).toBe(33834) + expect(buf.tell()).toBe(2) + buf.incr(-1) + expect(buf.readUint16()).toBe(10769) + expect(buf.tell()).toBe(3) + expect(buf.readUint16()).toBe(1024) + expect(buf.tell()).toBe(5) + + buf.seek(0) + expect(buf.readUint24()).toBe(8661521) + expect(buf.tell()).toBe(3) + buf.seek(1) + expect(buf.readUint24()).toBe(2756868) + expect(buf.tell()).toBe(4) + + buf.seek(0) + expect(buf.readUint32()).toBe(2217349380) + expect(buf.tell()).toBe(4) + buf.seek(1) + expect(buf.readUint32()).toBe(705758208) + expect(buf.tell()).toBe(5) + }) + + it("errors if reading integer primitives past the end of the buffer", () => { + const buf = new BufferReader(arrayToBytes([132, 42, 17, 4, 1])) + buf.seek(5) + expect(() => buf.readUint8()).toThrow(TLSError) + buf.seek(5) + expect(() => buf.readUint16()).toThrow(TLSError) + buf.seek(5) + expect(() => buf.readUint24()).toThrow(TLSError) + buf.seek(5) + expect(() => buf.readUint32()).toThrow(TLSError) + + buf.seek(4) + expect(buf.readUint8()).toBeTruthy() + buf.seek(4) + expect(() => buf.readUint16()).toThrow(TLSError) + buf.seek(4) + expect(() => buf.readUint24()).toThrow(TLSError) + buf.seek(4) + expect(() => buf.readUint32()).toThrow(TLSError) + + buf.seek(3) + expect(buf.readUint8()).toBeTruthy() + buf.seek(3) + expect(buf.readUint16()).toBeTruthy() + buf.seek(3) + expect(() => buf.readUint24()).toThrow(TLSError) + buf.seek(3) + expect(() => buf.readUint32()).toThrow(TLSError) + + buf.seek(2) + expect(buf.readUint8()).toBeTruthy() + buf.seek(2) + expect(buf.readUint16()).toBeTruthy() + buf.seek(2) + expect(buf.readUint24()).toBeTruthy() + buf.seek(2) + expect(() => buf.readUint32()).toThrow(TLSError) + }) + + it("correctly reads variable-length vectors of bytes", () => { + let buf = new BufferReader(arrayToBytes([4, 1, 2, 3, 4, 5])) + expect( + bytesAreEqual(buf.readVectorBytes8(), arrayToBytes([1, 2, 3, 4])) + ).toBe(true) + expect(buf.tell()).toBe(5) + buf = new BufferReader(arrayToBytes([0, 0, 0])) + expect(bytesAreEqual(buf.readVectorBytes8(), arrayToBytes([]))).toBe(true) + expect(buf.tell()).toBe(1) + + buf = new BufferReader(arrayToBytes([0, 4, 1, 2, 3, 4, 5])) + expect( + bytesAreEqual(buf.readVectorBytes16(), arrayToBytes([1, 2, 3, 4])) + ).toBe(true) + expect(buf.tell()).toBe(6) + + buf = new BufferReader(arrayToBytes([0, 0, 4, 1, 2, 3, 4, 5])) + expect( + bytesAreEqual(buf.readVectorBytes24(), arrayToBytes([1, 2, 3, 4])) + ).toBe(true) + expect(buf.tell()).toBe(7) + }) + + it("correctly reads variable-length vectors using a callback", () => { + let readValues: number[] = [] + let buf = new BufferReader(arrayToBytes([42, 4, 1, 2, 3, 4, 5])) + buf.seek(1) + buf.readVector8((contentsBuf, n) => { + expect(contentsBuf.length()).toBe(4) + expect(n).toBe(readValues.length) + readValues.push(contentsBuf.readUint8()) + }) + expect(readValues).toEqual([1, 2, 3, 4]) + expect(buf.tell()).toBe(6) + + readValues = [] + buf = new BufferReader(arrayToBytes([42, 0, 4, 1, 2, 3, 4, 5])) + buf.seek(1) + buf.readVector16((contentsBuf, n) => { + expect(contentsBuf.length()).toBe(4) + expect(n).toBe(readValues.length) + readValues.push(contentsBuf.readUint16()) + }) + expect(readValues).toEqual([(1 << 8) | 2, (3 << 8) | 4]) + expect(buf.tell()).toBe(7) + + readValues = [] + buf = new BufferReader(arrayToBytes([42, 0, 0, 4, 1, 2, 3, 4, 5])) + buf.seek(1) + buf.readVector24((contentsBuf, n) => { + expect(contentsBuf.length()).toBe(4) + expect(n).toBe(readValues.length) + readValues.push(contentsBuf.readUint8()) + }) + expect(readValues).toEqual([1, 2, 3, 4]) + expect(buf.tell()).toBe(8) + }) + + it("errors if a vector read consumes too many bytes", () => { + const buf = new BufferReader(arrayToBytes([2, 1, 2, 3])) + expect(() => { + buf.readVector8((contentsBuf) => { + expect(contentsBuf.length()).toBe(2) + contentsBuf.readUint24() + }) + }).toThrow(TLSError) + }) + + it("errors if a vector read somehow consumes too few bytes", () => { + const buf = new BufferReader(arrayToBytes([3, 1, 2, 3])) + expect(() => { + buf.readVector8((contentsBuf) => { + expect(contentsBuf.length()).toBe(3) + expect(contentsBuf.readUint8()).toBe(1) + expect(contentsBuf.readUint8()).toBe(2) + expect(contentsBuf.readUint8()).toBe(3) + // simulate some bug that changes the underlying buffer. + buf.incr(-1) + }) + }).toThrow(TLSError) + }) + + it("errors if a vector read consumes no bytes", () => { + const buf = new BufferReader(arrayToBytes([3, 1, 2, 3])) + expect(() => { + buf.readVector8((contentsBuf) => { + expect(contentsBuf.length()).toBe(3) + // don't consume anything, risking an infinite loop. + }) + }).toThrow(TLSError) + }) + + it("errors if a nested vector read would exceed the outer buffer length", () => { + // A vector of length 5, inside a vector of length 3. + const buf = new BufferReader(arrayToBytes([3, 5, 1, 2, 3, 4, 5])) + expect(() => { + buf.readVector8((contentsBuf) => { + expect(contentsBuf.length()).toBe(3) + contentsBuf.readVector8(() => { + expect.fail("the callback should not get called") + }) + }) + }).toThrow(TLSError) + }) +}) + +describe("the BufferWriter class", () => { + it("grows appropriately as data is written", () => { + const buf = new BufferWriter(2) + buf.writeBytes(arrayToBytes([1, 2, 3, 4, 5])) + expect(buf.tell()).toBe(5) + expect(buf.length()).toBe(6) + }) + + it("can read back written data using `slice`", () => { + const buf = new BufferWriter(2) + buf.writeBytes(arrayToBytes([1, 2, 3, 4, 5])) + expect(buf.tell()).toBe(5) + expect(bytesAreEqual(buf.slice(), arrayToBytes([1, 2, 3, 4, 5]))).toBe(true) + expect(buf.tell()).toBe(5) + expect(bytesAreEqual(buf.slice(1), arrayToBytes([2, 3, 4, 5]))).toBe(true) + expect(buf.tell()).toBe(5) + expect(bytesAreEqual(buf.slice(1, 3), arrayToBytes([2, 3]))).toBe(true) + expect(buf.tell()).toBe(5) + expect(bytesAreEqual(buf.slice(1, -1), arrayToBytes([2, 3, 4]))).toBe(true) + expect(buf.tell()).toBe(5) + }) + + it("refuses to slice past the start of the buffer", () => { + const buf = new BufferWriter(2) + buf.writeBytes(arrayToBytes([1, 2, 3, 4, 5])) + expect(() => buf.slice(-1)).toThrow(TLSError) + expect(() => buf.slice(0, -50)).toThrow(TLSError) + }) + + it("refuses to slice past the end of the buffer", () => { + const buf = new BufferWriter(2) + buf.writeBytes(arrayToBytes([1, 2, 3, 4, 5])) + expect(() => buf.slice(2, 50)).toThrow(TLSError) + }) + + it("returns and resets the buffer on flush", () => { + const buf = new BufferWriter() + buf.writeBytes(arrayToBytes([1, 2, 3, 4, 5])) + expect(bytesAreEqual(buf.flush(), arrayToBytes([1, 2, 3, 4, 5]))).toBe(true) + expect(buf.tell()).toBe(0) + expect(bytesAreEqual(buf.slice(), arrayToBytes([]))).toBe(true) + }) + + it("truncates at the current position on flush", () => { + const buf = new BufferWriter() + buf.writeBytes(arrayToBytes([1, 2, 3, 4, 5])) + buf.incr(-2) + expect(bytesAreEqual(buf.flush(), arrayToBytes([1, 2, 3]))).toBe(true) + expect(buf.tell()).toBe(0) + }) + + it("correctly writes integer primitives at various offsets", () => { + const buf = new BufferWriter() + for (let i = 0; i < 10; i++) { + buf.writeUint8(i) + } + expect(buf.tell()).toBe(10) + expect( + bytesAreEqual( + buf.flush(), + arrayToBytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + ) + ).toBe(true) + + for (let i = 0; i < 9; i++) { + buf.writeUint16(i) + } + buf.writeUint16(3079) + expect(buf.tell()).toBe(20) + expect( + bytesAreEqual( + buf.flush(), + arrayToBytes([ + 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 12, 7, + ]) + ) + ).toBe(true) + + for (let i = 0; i < 5; i++) { + buf.writeUint24(i) + } + buf.writeUint24(788229) + expect(buf.tell()).toBe(18) + expect( + bytesAreEqual( + buf.flush(), + arrayToBytes([ + 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 12, 7, 5, + ]) + ) + ).toBe(true) + + for (let i = 0; i < 5; i++) { + buf.writeUint32(i) + } + buf.writeUint32(201786627) + expect(buf.tell()).toBe(24) + expect( + bytesAreEqual( + buf.flush(), + arrayToBytes([ + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 12, + 7, 5, 3, + ]) + ) + ).toBe(true) + + buf.writeUint8(1) + buf.writeUint16(2) + buf.writeUint24(3) + buf.writeUint8(4) + buf.writeUint32(5) + expect( + bytesAreEqual( + buf.flush(), + arrayToBytes([1, 0, 2, 0, 0, 3, 4, 0, 0, 0, 5]) + ) + ).toBe(true) + }) + + it("correctly writes variable-length vectors of bytes", () => { + let buf = new BufferWriter() + buf.writeVectorBytes8(arrayToBytes([1, 2, 3, 4, 5])) + expect( + bytesAreEqual(buf.flush(), arrayToBytes([5, 1, 2, 3, 4, 5])) + ).toBe(true) + + buf = new BufferWriter() + buf.writeVectorBytes16(arrayToBytes([1, 2, 3, 4, 5])) + expect( + bytesAreEqual(buf.flush(), arrayToBytes([0, 5, 1, 2, 3, 4, 5])) + ).toBe(true) + + buf = new BufferWriter() + buf.writeVectorBytes24(arrayToBytes([1, 2, 3, 4, 5])) + expect( + bytesAreEqual(buf.flush(), arrayToBytes([0, 0, 5, 1, 2, 3, 4, 5])) + ).toBe(true) + }) + + it("correctly writes variable-length vectors using a callback", () => { + let buf = new BufferWriter() + buf.writeVector8((buf) => { + buf.writeUint8(1) + buf.writeUint8(2) + buf.writeUint8(3) + buf.writeUint8(4) + buf.writeUint8(5) + }) + expect( + bytesAreEqual(buf.flush(), arrayToBytes([5, 1, 2, 3, 4, 5])) + ).toBe(true) + + buf = new BufferWriter() + buf.writeVector16((buf) => { + buf.writeUint8(1) + buf.writeUint8(2) + buf.writeUint8(3) + buf.writeUint8(4) + buf.writeUint8(5) + }) + expect( + bytesAreEqual(buf.flush(), arrayToBytes([0, 5, 1, 2, 3, 4, 5])) + ).toBe(true) + + buf = new BufferWriter() + buf.writeVector24((buf) => { + buf.writeUint8(1) + buf.writeUint8(2) + buf.writeUint8(3) + buf.writeUint8(4) + buf.writeUint8(5) + }) + expect( + bytesAreEqual(buf.flush(), arrayToBytes([0, 0, 5, 1, 2, 3, 4, 5])) + ).toBe(true) + }) + + it("correctly writes nested variable-length vectors using nested callback", () => { + const buf = new BufferWriter() + buf.writeVector16((buf) => { + buf.writeVector8((buf) => { + buf.writeUint8(1) + buf.writeUint8(2) + buf.writeUint8(3) + buf.writeUint8(4) + buf.writeUint8(5) + }) + }) + expect( + bytesAreEqual(buf.flush(), arrayToBytes([0, 6, 5, 1, 2, 3, 4, 5])) + ).toBe(true) + }) + + it("errors if a vector write exceeds the maximum size representable in its length field", () => { + let buf = new BufferWriter() + buf.writeVectorBytes8(zeros(255)) + expect(() => buf.writeVectorBytes8(zeros(256))).toThrow(TLSError) + + buf = new BufferWriter() + buf.writeVectorBytes16(zeros(65535)) + expect(() => buf.writeVectorBytes16(zeros(65536))).toThrow(TLSError) + + // Skip the 24-bit test as it requires 16MB allocation + + buf = new BufferWriter() + expect(() => { + buf.writeVector8((buf) => { + buf.writeBytes(zeros(256)) + }) + }).toThrow(TLSError) + + buf = new BufferWriter() + expect(() => { + buf.writeVector16((buf) => { + buf.writeBytes(zeros(65536)) + }) + }).toThrow(TLSError) + }) +}) + +// Ensure bytesToHex works (used in tests) +describe("bytesToHex", () => { + it("converts bytes to hex string", () => { + expect(bytesToHex(arrayToBytes([0, 1, 255]))).toBe("0001ff") + }) +}) diff --git a/frontend/src/lib/pairing-channel/alerts.ts b/frontend/src/lib/pairing-channel/alerts.ts new file mode 100644 index 00000000..f22f0625 --- /dev/null +++ b/frontend/src/lib/pairing-channel/alerts.ts @@ -0,0 +1,78 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +export const ALERT_LEVEL = { + WARNING: 1, + FATAL: 2, +} as const + +export const ALERT_DESCRIPTION = { + CLOSE_NOTIFY: 0, + UNEXPECTED_MESSAGE: 10, + BAD_RECORD_MAC: 20, + RECORD_OVERFLOW: 22, + HANDSHAKE_FAILURE: 40, + ILLEGAL_PARAMETER: 47, + DECODE_ERROR: 50, + DECRYPT_ERROR: 51, + PROTOCOL_VERSION: 70, + INTERNAL_ERROR: 80, + MISSING_EXTENSION: 109, + UNSUPPORTED_EXTENSION: 110, + UNKNOWN_PSK_IDENTITY: 115, + NO_APPLICATION_PROTOCOL: 120, +} as const + +function alertTypeToName(type: number): string { + for (const name in ALERT_DESCRIPTION) { + if ( + ALERT_DESCRIPTION[name as keyof typeof ALERT_DESCRIPTION] === type + ) { + return `${name} (${type})` + } + } + return `UNKNOWN (${type})` +} + +export class TLSAlert extends Error { + description: number + level: number + + constructor(description: number, level: number) { + super(`TLS Alert: ${alertTypeToName(description)}`) + this.description = description + this.level = level + } + + static fromBytes(bytes: Uint8Array): TLSAlert { + if (bytes.byteLength !== 2) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + switch (bytes[1]) { + case ALERT_DESCRIPTION.CLOSE_NOTIFY: + if (bytes[0] !== ALERT_LEVEL.WARNING) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + return new TLSCloseNotify() + default: + return new TLSError(bytes[1]) + } + } + + toBytes(): Uint8Array { + return new Uint8Array([this.level, this.description]) + } +} + +export class TLSCloseNotify extends TLSAlert { + constructor() { + super(ALERT_DESCRIPTION.CLOSE_NOTIFY, ALERT_LEVEL.WARNING) + } +} + +export class TLSError extends TLSAlert { + constructor(description: number = ALERT_DESCRIPTION.INTERNAL_ERROR) { + super(description, ALERT_LEVEL.FATAL) + } +} diff --git a/frontend/src/lib/pairing-channel/constants.ts b/frontend/src/lib/pairing-channel/constants.ts new file mode 100644 index 00000000..13cb9e08 --- /dev/null +++ b/frontend/src/lib/pairing-channel/constants.ts @@ -0,0 +1,9 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +export const VERSION_TLS_1_0 = 0x0301 +export const VERSION_TLS_1_2 = 0x0303 +export const VERSION_TLS_1_3 = 0x0304 +export const TLS_AES_128_GCM_SHA256 = 0x1301 +export const PSK_MODE_KE = 0 diff --git a/frontend/src/lib/pairing-channel/crypto.ts b/frontend/src/lib/pairing-channel/crypto.ts new file mode 100644 index 00000000..03b63946 --- /dev/null +++ b/frontend/src/lib/pairing-channel/crypto.ts @@ -0,0 +1,161 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { utf8ToBytes, BufferWriter } from "./utils" +import { ALERT_DESCRIPTION, TLSError } from "./alerts" + +export const AEAD_SIZE_INFLATION = 16 +export const KEY_LENGTH = 16 +export const IV_LENGTH = 12 +export const HASH_LENGTH = 32 + +// Helper to ensure Uint8Array has an ArrayBuffer (not SharedArrayBuffer) backing, +// which is required by WebCrypto APIs in strict TypeScript. +function toBuffer(bytes: Uint8Array): Uint8Array { + return bytes as Uint8Array +} + +export async function prepareKey( + key: Uint8Array, + mode: "encrypt" | "decrypt" +): Promise { + return crypto.subtle.importKey("raw", toBuffer(key), { name: "AES-GCM" }, false, [mode]) +} + +export async function encrypt( + key: CryptoKey, + iv: Uint8Array, + plaintext: Uint8Array, + additionalData: Uint8Array +): Promise { + const ciphertext = await crypto.subtle.encrypt( + { + additionalData: toBuffer(additionalData), + iv: toBuffer(iv), + name: "AES-GCM", + tagLength: AEAD_SIZE_INFLATION * 8, + }, + key, + toBuffer(plaintext) + ) + return new Uint8Array(ciphertext) +} + +export async function decrypt( + key: CryptoKey, + iv: Uint8Array, + ciphertext: Uint8Array, + additionalData: Uint8Array +): Promise { + try { + const plaintext = await crypto.subtle.decrypt( + { + additionalData: toBuffer(additionalData), + iv: toBuffer(iv), + name: "AES-GCM", + tagLength: AEAD_SIZE_INFLATION * 8, + }, + key, + toBuffer(ciphertext) + ) + return new Uint8Array(plaintext) + } catch { + throw new TLSError(ALERT_DESCRIPTION.BAD_RECORD_MAC) + } +} + +export async function hash(message: Uint8Array): Promise { + return new Uint8Array( + await crypto.subtle.digest({ name: "SHA-256" }, toBuffer(message)) + ) +} + +export async function hmac( + keyBytes: Uint8Array, + message: Uint8Array +): Promise { + const key = await crypto.subtle.importKey( + "raw", + toBuffer(keyBytes), + { + hash: { name: "SHA-256" }, + name: "HMAC", + }, + false, + ["sign"] + ) + const sig = await crypto.subtle.sign({ name: "HMAC" }, key, toBuffer(message)) + return new Uint8Array(sig) +} + +export async function verifyHmac( + keyBytes: Uint8Array, + signature: Uint8Array, + message: Uint8Array +): Promise { + const key = await crypto.subtle.importKey( + "raw", + toBuffer(keyBytes), + { + hash: { name: "SHA-256" }, + name: "HMAC", + }, + false, + ["verify"] + ) + if (!(await crypto.subtle.verify({ name: "HMAC" }, key, toBuffer(signature), toBuffer(message)))) { + throw new TLSError(ALERT_DESCRIPTION.DECRYPT_ERROR) + } +} + +export async function hkdfExtract( + salt: Uint8Array, + ikm: Uint8Array +): Promise { + return await hmac(salt, ikm) +} + +export async function hkdfExpand( + prk: Uint8Array, + info: Uint8Array, + length: number +): Promise { + const N = Math.ceil(length / HASH_LENGTH) + if (N <= 0) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + if (N >= 255) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + const input = new BufferWriter() + const output = new BufferWriter() + let T: Uint8Array = new Uint8Array(0) + for (let i = 1; i <= N; i++) { + input.writeBytes(T) + input.writeBytes(info) + input.writeUint8(i) + T = await hmac(prk, input.flush()) + output.writeBytes(T) + } + return output.slice(0, length) +} + +export async function hkdfExpandLabel( + secret: Uint8Array, + label: string, + context: Uint8Array, + length: number +): Promise { + const hkdfLabel = new BufferWriter() + hkdfLabel.writeUint16(length) + hkdfLabel.writeVectorBytes8(utf8ToBytes("tls13 " + label)) + hkdfLabel.writeVectorBytes8(context) + return hkdfExpand(secret, hkdfLabel.flush(), length) +} + +export async function getRandomBytes(size: number): Promise> { + const bytes = new Uint8Array(size) as Uint8Array + crypto.getRandomValues(bytes) + return bytes +} diff --git a/frontend/src/lib/pairing-channel/extensions.ts b/frontend/src/lib/pairing-channel/extensions.ts new file mode 100644 index 00000000..c6dc75e8 --- /dev/null +++ b/frontend/src/lib/pairing-channel/extensions.ts @@ -0,0 +1,245 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { ALERT_DESCRIPTION, TLSError } from "./alerts" +import { HANDSHAKE_TYPE } from "./messages" +import { HASH_LENGTH } from "./crypto" +import type { BufferReader, BufferWriter } from "./utils" + +export const EXTENSION_TYPE = { + PRE_SHARED_KEY: 41, + SUPPORTED_VERSIONS: 43, + PSK_KEY_EXCHANGE_MODES: 45, +} as const + +export interface ExtensionLike { + TYPE_TAG: number + write(messageType: number, buf: BufferWriter): void +} + +export class Extension { + get TYPE_TAG(): number { + throw new Error("not implemented") + } + + static read(messageType: number, buf: BufferReader): ExtensionLike { + const type = buf.readUint16() + let ext: ExtensionLike = { + TYPE_TAG: type, + write() { + throw new Error("not implemented") + }, + } + buf.readVector16((buf) => { + switch (type) { + case EXTENSION_TYPE.PRE_SHARED_KEY: + ext = PreSharedKeyExtension._read(messageType, buf) + break + case EXTENSION_TYPE.SUPPORTED_VERSIONS: + ext = SupportedVersionsExtension._read(messageType, buf) + break + case EXTENSION_TYPE.PSK_KEY_EXCHANGE_MODES: + ext = PskKeyExchangeModesExtension._read(messageType, buf) + break + default: + // Skip over unrecognised extensions. + buf.incr(buf.length()) + } + if (buf.hasMoreBytes()) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + }) + return ext + } + + write(messageType: number, buf: BufferWriter): void { + buf.writeUint16(this.TYPE_TAG) + buf.writeVector16((buf) => { + this._write(messageType, buf) + }) + } + + static _read(_messageType: number, _buf: BufferReader): Extension { + throw new Error("not implemented") + } + + _write(_messageType: number, _buf: BufferWriter): void { + throw new Error("not implemented") + } +} + +export class PreSharedKeyExtension extends Extension { + identities: Uint8Array[] | null + binders: Uint8Array[] | null + selectedIdentity: number | null + + constructor( + identities: Uint8Array[] | null, + binders: Uint8Array[] | null, + selectedIdentity: number | null + ) { + super() + this.identities = identities + this.binders = binders + this.selectedIdentity = selectedIdentity + } + + get TYPE_TAG(): number { + return EXTENSION_TYPE.PRE_SHARED_KEY + } + + static _read(messageType: number, buf: BufferReader): PreSharedKeyExtension { + let identities: Uint8Array[] | null = null, + binders: Uint8Array[] | null = null, + selectedIdentity: number | null = null + switch (messageType) { + case HANDSHAKE_TYPE.CLIENT_HELLO: + identities = [] + binders = [] + buf.readVector16((buf) => { + const identity = buf.readVectorBytes16() + buf.readBytes(4) // Skip over the ticket age. + identities!.push(identity) + }) + buf.readVector16((buf) => { + const binder = buf.readVectorBytes8() + if (binder.byteLength < HASH_LENGTH) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + binders!.push(binder) + }) + if (identities.length !== binders.length) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + break + case HANDSHAKE_TYPE.SERVER_HELLO: + selectedIdentity = buf.readUint16() + break + default: + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + return new this(identities, binders, selectedIdentity) + } + + _write(messageType: number, buf: BufferWriter): void { + switch (messageType) { + case HANDSHAKE_TYPE.CLIENT_HELLO: + buf.writeVector16((buf) => { + this.identities!.forEach((pskId) => { + buf.writeVectorBytes16(pskId) + buf.writeUint32(0) + }) + }) + buf.writeVector16((buf) => { + this.binders!.forEach((pskBinder) => { + buf.writeVectorBytes8(pskBinder) + }) + }) + break + case HANDSHAKE_TYPE.SERVER_HELLO: + buf.writeUint16(this.selectedIdentity!) + break + default: + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + } +} + +export class SupportedVersionsExtension extends Extension { + versions: number[] | null + selectedVersion: number | null + + constructor(versions: number[] | null, selectedVersion?: number | null) { + super() + this.versions = versions + this.selectedVersion = selectedVersion ?? null + } + + get TYPE_TAG(): number { + return EXTENSION_TYPE.SUPPORTED_VERSIONS + } + + static _read( + messageType: number, + buf: BufferReader + ): SupportedVersionsExtension { + let versions: number[] | null = null, + selectedVersion: number | null = null + switch (messageType) { + case HANDSHAKE_TYPE.CLIENT_HELLO: + versions = [] + buf.readVector8((buf) => { + versions!.push(buf.readUint16()) + }) + break + case HANDSHAKE_TYPE.SERVER_HELLO: + selectedVersion = buf.readUint16() + break + default: + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + return new this(versions, selectedVersion) + } + + _write(messageType: number, buf: BufferWriter): void { + switch (messageType) { + case HANDSHAKE_TYPE.CLIENT_HELLO: + buf.writeVector8((buf) => { + this.versions!.forEach((version) => { + buf.writeUint16(version) + }) + }) + break + case HANDSHAKE_TYPE.SERVER_HELLO: + buf.writeUint16(this.selectedVersion!) + break + default: + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + } +} + +export class PskKeyExchangeModesExtension extends Extension { + modes: number[] + + constructor(modes: number[]) { + super() + this.modes = modes + } + + get TYPE_TAG(): number { + return EXTENSION_TYPE.PSK_KEY_EXCHANGE_MODES + } + + static _read( + messageType: number, + buf: BufferReader + ): PskKeyExchangeModesExtension { + const modes: number[] = [] + switch (messageType) { + case HANDSHAKE_TYPE.CLIENT_HELLO: + buf.readVector8((buf) => { + modes.push(buf.readUint8()) + }) + break + default: + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + return new this(modes) + } + + _write(messageType: number, buf: BufferWriter): void { + switch (messageType) { + case HANDSHAKE_TYPE.CLIENT_HELLO: + buf.writeVector8((buf) => { + this.modes.forEach((mode) => { + buf.writeUint8(mode) + }) + }) + break + default: + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + } +} diff --git a/frontend/src/lib/pairing-channel/index.ts b/frontend/src/lib/pairing-channel/index.ts new file mode 100644 index 00000000..0a55485e --- /dev/null +++ b/frontend/src/lib/pairing-channel/index.ts @@ -0,0 +1,237 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// A wrapper that combines a WebSocket to the channelserver +// with some client-side encryption for securing the channel. + +import { ClientConnection, ServerConnection } from "./tlsconnection" +import { TLSCloseNotify, TLSError } from "./alerts" +import { + base64urlToBytes, + bytesToBase64url, + bytesToHex, + bytesToUtf8, + hexToBytes, + utf8ToBytes, +} from "./utils" + +const CLOSE_FLUSH_BUFFER_INTERVAL_MS = 200 +const CLOSE_FLUSH_BUFFER_MAX_TRIES = 5 + +export class PairingChannel extends EventTarget { + _channelId: string + _channelKey: Uint8Array + _socket: WebSocket | null + _connection: ClientConnection | ServerConnection | null + _selfClosed: boolean + _peerClosed: boolean + + constructor( + channelId: string, + channelKey: Uint8Array, + socket: WebSocket, + connection: ClientConnection | ServerConnection + ) { + super() + this._channelId = channelId + this._channelKey = channelKey + this._socket = socket + this._connection = connection + this._selfClosed = false + this._peerClosed = false + this._setupListeners() + } + + /** + * Create a new pairing channel. + * + * This will open a channel on the channelserver, and generate a random client-side + * encryption key. When the promise resolves, `this.channelId` and `this.channelKey` + * can be transferred to another client to allow it to securely connect to the channel. + * + * @returns Promise + */ + static create(channelServerURI: string): Promise { + const wsURI = channelServerURI.replace(/\/+$/, "") + const channelKey = crypto.getRandomValues(new Uint8Array(32)) + // The one who creates the channel plays the role of 'server' in the underlying TLS exchange. + return this._makePairingChannel(wsURI, ServerConnection, channelKey) + } + + /** + * Connect to an existing pairing channel. + * + * This will connect to a channel on the channelserver previously established by + * another client calling `create`. The `channelId` and `channelKey` must have been + * obtained via some out-of-band mechanism (such as by scanning from a QR code). + * + * @returns Promise + */ + static connect( + channelServerURI: string, + channelId: string, + channelKey: Uint8Array + ): Promise { + const wsURI = `${channelServerURI.replace(/\/+$/, "")}?channelId=${channelId}` + // The one who connects to an existing channel plays the role of 'client' + // in the underlying TLS exchange. + return this._makePairingChannel(wsURI, ClientConnection, channelKey) + } + + static _makePairingChannel( + wsUri: string, + ConnectionClass: typeof ClientConnection | typeof ServerConnection, + psk: Uint8Array + ): Promise { + const socket = new WebSocket(wsUri) + return new Promise((resolve, reject) => { + let stopListening: () => void + const onConnectionError = async () => { + stopListening() + reject(new Error("Error while creating the pairing channel")) + } + const onFirstMessage = async (event: MessageEvent) => { + stopListening() + try { + const { channelid: channelId } = JSON.parse(event.data) + const pskId = utf8ToBytes(channelId) + const connection = await ConnectionClass.create( + psk, + pskId, + (data: Uint8Array) => { + socket.send(bytesToBase64url(data)) + } + ) + const instance = new this(channelId, psk, socket, connection) + resolve(instance) + } catch (err) { + reject(err) + } + } + stopListening = () => { + socket.removeEventListener("close", onConnectionError) + socket.removeEventListener("error", onConnectionError) + socket.removeEventListener("message", onFirstMessage) + } + socket.addEventListener("close", onConnectionError) + socket.addEventListener("error", onConnectionError) + socket.addEventListener("message", onFirstMessage) + }) + } + + _setupListeners(): void { + this._socket!.addEventListener("message", async (event: MessageEvent) => { + try { + const channelServerEnvelope = JSON.parse(event.data) + const payload = await this._connection!.recv( + base64urlToBytes(channelServerEnvelope.message) + ) + if (payload !== null) { + const data = JSON.parse(bytesToUtf8(payload)) + this.dispatchEvent( + new CustomEvent("message", { + detail: { + data, + sender: channelServerEnvelope.sender, + }, + }) + ) + } + } catch (error) { + let event: CustomEvent + if (error instanceof TLSCloseNotify) { + this._peerClosed = true + if (this._selfClosed) { + this._shutdown() + } + event = new CustomEvent("close") + } else { + event = new CustomEvent("error", { + detail: { + error, + }, + }) + } + this.dispatchEvent(event) + } + }) + this._socket!.addEventListener("error", () => { + this._shutdown() + this.dispatchEvent( + new CustomEvent("error", { + detail: { + error: new Error("WebSocket error."), + }, + }) + ) + }) + this._socket!.addEventListener("close", () => { + this._shutdown() + if (!this._peerClosed) { + this.dispatchEvent( + new CustomEvent("error", { + detail: { + error: new Error("WebSocket unexpectedly closed"), + }, + }) + ) + } + }) + } + + async send(data: Record): Promise { + const payload = utf8ToBytes(JSON.stringify(data)) + await this._connection!.send(payload) + } + + async close(): Promise { + this._selfClosed = true + await this._connection!.close() + try { + let tries = 0 + while (this._socket!.bufferedAmount > 0) { + if (++tries > CLOSE_FLUSH_BUFFER_MAX_TRIES) { + throw new Error("Could not flush the outgoing buffer in time.") + } + await new Promise((res) => setTimeout(res, CLOSE_FLUSH_BUFFER_INTERVAL_MS)) + } + } finally { + if (this._peerClosed) { + this._shutdown() + } + } + } + + _shutdown(): void { + if (this._socket) { + this._socket.close() + this._socket = null + this._connection = null + } + } + + get closed(): boolean { + return !this._socket || this._socket.readyState === 3 + } + + get channelId(): string { + return this._channelId + } + + get channelKey(): Uint8Array { + return this._channelKey + } +} + +// Re-export helpful utilities for calling code to use. +export { + base64urlToBytes, + bytesToBase64url, + bytesToHex, + bytesToUtf8, + hexToBytes, + TLSCloseNotify, + TLSError, + utf8ToBytes, +} diff --git a/frontend/src/lib/pairing-channel/keyschedule.ts b/frontend/src/lib/pairing-channel/keyschedule.ts new file mode 100644 index 00000000..5eac2adc --- /dev/null +++ b/frontend/src/lib/pairing-channel/keyschedule.ts @@ -0,0 +1,132 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { BufferWriter, EMPTY, zeros } from "./utils" +import { ALERT_DESCRIPTION, TLSError } from "./alerts" +import { + hkdfExtract, + hkdfExpandLabel, + HASH_LENGTH, + hash, + hmac, + verifyHmac, +} from "./crypto" + +const STAGE_UNINITIALIZED = 0 +const STAGE_EARLY_SECRET = 1 +const STAGE_HANDSHAKE_SECRET = 2 +const STAGE_MASTER_SECRET = 3 + +export class KeySchedule { + stage: number + transcript: BufferWriter + secret: Uint8Array | null + extBinderKey: Uint8Array | null + clientHandshakeTrafficSecret: Uint8Array | null + serverHandshakeTrafficSecret: Uint8Array | null + clientApplicationTrafficSecret: Uint8Array | null + serverApplicationTrafficSecret: Uint8Array | null + + constructor() { + this.stage = STAGE_UNINITIALIZED + this.transcript = new BufferWriter() + this.secret = null + this.extBinderKey = null + this.clientHandshakeTrafficSecret = null + this.serverHandshakeTrafficSecret = null + this.clientApplicationTrafficSecret = null + this.serverApplicationTrafficSecret = null + } + + async addPSK(psk: Uint8Array | null): Promise { + if (psk === null) { + psk = zeros(HASH_LENGTH) + } + if (this.stage !== STAGE_UNINITIALIZED) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + this.stage = STAGE_EARLY_SECRET + this.secret = await hkdfExtract(zeros(HASH_LENGTH), psk) + this.extBinderKey = await this.deriveSecret("ext binder", EMPTY) + this.secret = await this.deriveSecret("derived", EMPTY) + } + + async addECDHE(ecdhe: Uint8Array | null): Promise { + if (ecdhe === null) { + ecdhe = zeros(HASH_LENGTH) + } + if (this.stage !== STAGE_EARLY_SECRET) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + this.stage = STAGE_HANDSHAKE_SECRET + this.extBinderKey = null + this.secret = await hkdfExtract(this.secret!, ecdhe) + this.clientHandshakeTrafficSecret = await this.deriveSecret("c hs traffic") + this.serverHandshakeTrafficSecret = await this.deriveSecret("s hs traffic") + this.secret = await this.deriveSecret("derived", EMPTY) + } + + async finalize(): Promise { + if (this.stage !== STAGE_HANDSHAKE_SECRET) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + this.stage = STAGE_MASTER_SECRET + this.clientHandshakeTrafficSecret = null + this.serverHandshakeTrafficSecret = null + this.secret = await hkdfExtract(this.secret!, zeros(HASH_LENGTH)) + this.clientApplicationTrafficSecret = await this.deriveSecret("c ap traffic") + this.serverApplicationTrafficSecret = await this.deriveSecret("s ap traffic") + this.secret = null + } + + addToTranscript(bytes: Uint8Array): void { + this.transcript.writeBytes(bytes) + } + + getTranscript(): Uint8Array { + return this.transcript.slice() + } + + async deriveSecret( + label: string, + transcript?: Uint8Array + ): Promise { + transcript = transcript || this.getTranscript() + return await hkdfExpandLabel( + this.secret!, + label, + await hash(transcript), + HASH_LENGTH + ) + } + + async calculateFinishedMAC( + baseKey: Uint8Array, + transcript?: Uint8Array + ): Promise { + transcript = transcript || this.getTranscript() + const finishedKey = await hkdfExpandLabel( + baseKey, + "finished", + EMPTY, + HASH_LENGTH + ) + return await hmac(finishedKey, await hash(transcript)) + } + + async verifyFinishedMAC( + baseKey: Uint8Array, + mac: Uint8Array, + transcript?: Uint8Array + ): Promise { + transcript = transcript || this.getTranscript() + const finishedKey = await hkdfExpandLabel( + baseKey, + "finished", + EMPTY, + HASH_LENGTH + ) + await verifyHmac(finishedKey, mac, await hash(transcript)) + } +} diff --git a/frontend/src/lib/pairing-channel/messages.ts b/frontend/src/lib/pairing-channel/messages.ts new file mode 100644 index 00000000..d18718b2 --- /dev/null +++ b/frontend/src/lib/pairing-channel/messages.ts @@ -0,0 +1,374 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { BufferWriter, BufferReader } from "./utils" +import { ALERT_DESCRIPTION, TLSError } from "./alerts" +import { HASH_LENGTH } from "./crypto" +import { + Extension, + EXTENSION_TYPE, + type ExtensionLike, +} from "./extensions" +import { + VERSION_TLS_1_2, + VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, + VERSION_TLS_1_0, +} from "./constants" + +export const HANDSHAKE_TYPE = { + CLIENT_HELLO: 1, + SERVER_HELLO: 2, + NEW_SESSION_TICKET: 4, + ENCRYPTED_EXTENSIONS: 8, + FINISHED: 20, +} as const + +type ExtensionMap = Map & { lastSeenExtension?: number } + +export class HandshakeMessage { + get TYPE_TAG(): number { + throw new Error("not implemented") + } + + static fromBytes(bytes: Uint8Array): HandshakeMessage { + const buf = new BufferReader(bytes) + const msg = this.read(buf) + if (buf.hasMoreBytes()) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + return msg + } + + toBytes(): Uint8Array { + const buf = new BufferWriter() + this.write(buf) + return buf.flush() + } + + static read(buf: BufferReader): HandshakeMessage { + const type = buf.readUint8() + let msg: HandshakeMessage | null = null + buf.readVector24((buf) => { + switch (type) { + case HANDSHAKE_TYPE.CLIENT_HELLO: + msg = ClientHello._read(buf) + break + case HANDSHAKE_TYPE.SERVER_HELLO: + msg = ServerHello._read(buf) + break + case HANDSHAKE_TYPE.NEW_SESSION_TICKET: + msg = NewSessionTicket._read(buf) + break + case HANDSHAKE_TYPE.ENCRYPTED_EXTENSIONS: + msg = EncryptedExtensions._read(buf) + break + case HANDSHAKE_TYPE.FINISHED: + msg = Finished._read(buf) + break + } + if (buf.hasMoreBytes()) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + }) + if (msg === null) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + return msg + } + + write(buf: BufferWriter): void { + buf.writeUint8(this.TYPE_TAG) + buf.writeVector24((buf) => { + this._write(buf) + }) + } + + static _read(_buf: BufferReader): HandshakeMessage { + throw new Error("not implemented") + } + + _write(_buf: BufferWriter): void { + throw new Error("not implemented") + } + + static _readExtensions( + messageType: number, + buf: BufferReader + ): ExtensionMap { + const extensions: ExtensionMap = new Map() + buf.readVector16((buf) => { + const ext = Extension.read(messageType, buf) + if (extensions.has(ext.TYPE_TAG)) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + extensions.set(ext.TYPE_TAG, ext) + extensions.lastSeenExtension = ext.TYPE_TAG + }) + return extensions + } + + _writeExtensions(buf: BufferWriter, extensions: ExtensionLike[]): void { + buf.writeVector16((buf) => { + extensions.forEach((ext) => { + ext.write(this.TYPE_TAG, buf) + }) + }) + } +} + +export class ClientHello extends HandshakeMessage { + random: Uint8Array + sessionId: Uint8Array + extensions: ExtensionMap + + constructor( + random: Uint8Array, + sessionId: Uint8Array, + extensions: ExtensionMap | ExtensionLike[] + ) { + super() + this.random = random + this.sessionId = sessionId + if (Array.isArray(extensions)) { + const map: ExtensionMap = new Map() + for (const ext of extensions) { + map.set(ext.TYPE_TAG, ext) + map.lastSeenExtension = ext.TYPE_TAG + } + this.extensions = map + } else { + this.extensions = extensions + } + } + + get TYPE_TAG(): number { + return HANDSHAKE_TYPE.CLIENT_HELLO + } + + static _read(buf: BufferReader): ClientHello { + if (buf.readUint16() < VERSION_TLS_1_0) { + throw new TLSError(ALERT_DESCRIPTION.PROTOCOL_VERSION) + } + const random = buf.readBytes(32) + const sessionId = buf.readVectorBytes8() + let found = false + buf.readVector16((buf) => { + const cipherSuite = buf.readUint16() + if (cipherSuite === TLS_AES_128_GCM_SHA256) { + found = true + } + }) + if (!found) { + throw new TLSError(ALERT_DESCRIPTION.HANDSHAKE_FAILURE) + } + const legacyCompressionMethods = buf.readVectorBytes8() + if (legacyCompressionMethods.byteLength !== 1) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + if (legacyCompressionMethods[0] !== 0x00) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + const extensions = this._readExtensions(HANDSHAKE_TYPE.CLIENT_HELLO, buf) + if (!extensions.has(EXTENSION_TYPE.SUPPORTED_VERSIONS)) { + throw new TLSError(ALERT_DESCRIPTION.MISSING_EXTENSION) + } + const svExt = extensions.get(EXTENSION_TYPE.SUPPORTED_VERSIONS) as + unknown as { versions: number[] } + if (svExt.versions.indexOf(VERSION_TLS_1_3) === -1) { + throw new TLSError(ALERT_DESCRIPTION.PROTOCOL_VERSION) + } + if (extensions.has(EXTENSION_TYPE.PRE_SHARED_KEY)) { + if (extensions.lastSeenExtension !== EXTENSION_TYPE.PRE_SHARED_KEY) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + } + return new this(random, sessionId, extensions) + } + + _write(buf: BufferWriter): void { + buf.writeUint16(VERSION_TLS_1_2) + buf.writeBytes(this.random) + buf.writeVectorBytes8(this.sessionId) + buf.writeVector16((buf) => { + buf.writeUint16(TLS_AES_128_GCM_SHA256) + }) + buf.writeVectorBytes8(new Uint8Array(1)) + this._writeExtensions(buf, Array.from(this.extensions.values())) + } +} + +export class ServerHello extends HandshakeMessage { + random: Uint8Array + sessionId: Uint8Array + extensions: ExtensionMap + + constructor( + random: Uint8Array, + sessionId: Uint8Array, + extensions: ExtensionMap | ExtensionLike[] + ) { + super() + this.random = random + this.sessionId = sessionId + if (Array.isArray(extensions)) { + const map: ExtensionMap = new Map() + for (const ext of extensions) { + map.set(ext.TYPE_TAG, ext) + map.lastSeenExtension = ext.TYPE_TAG + } + this.extensions = map + } else { + this.extensions = extensions + } + } + + get TYPE_TAG(): number { + return HANDSHAKE_TYPE.SERVER_HELLO + } + + static _read(buf: BufferReader): ServerHello { + if (buf.readUint16() !== VERSION_TLS_1_2) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + const random = buf.readBytes(32) + const sessionId = buf.readVectorBytes8() + if (buf.readUint16() !== TLS_AES_128_GCM_SHA256) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + if (buf.readUint8() !== 0) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + const extensions = this._readExtensions(HANDSHAKE_TYPE.SERVER_HELLO, buf) + if (!extensions.has(EXTENSION_TYPE.SUPPORTED_VERSIONS)) { + throw new TLSError(ALERT_DESCRIPTION.MISSING_EXTENSION) + } + const svExtSH = extensions.get(EXTENSION_TYPE.SUPPORTED_VERSIONS) as + unknown as { selectedVersion: number } + if (svExtSH.selectedVersion !== VERSION_TLS_1_3) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + return new this(random, sessionId, extensions) + } + + _write(buf: BufferWriter): void { + buf.writeUint16(VERSION_TLS_1_2) + buf.writeBytes(this.random) + buf.writeVectorBytes8(this.sessionId) + buf.writeUint16(TLS_AES_128_GCM_SHA256) + buf.writeUint8(0) + this._writeExtensions(buf, Array.from(this.extensions.values())) + } +} + +export class EncryptedExtensions extends HandshakeMessage { + extensions: ExtensionMap + + constructor(extensions: ExtensionMap | ExtensionLike[]) { + super() + if (Array.isArray(extensions)) { + const map: ExtensionMap = new Map() + for (const ext of extensions) { + map.set(ext.TYPE_TAG, ext) + } + this.extensions = map + } else { + this.extensions = extensions + } + } + + get TYPE_TAG(): number { + return HANDSHAKE_TYPE.ENCRYPTED_EXTENSIONS + } + + static _read(buf: BufferReader): EncryptedExtensions { + const extensions = this._readExtensions( + HANDSHAKE_TYPE.ENCRYPTED_EXTENSIONS, + buf + ) + return new this(extensions) + } + + _write(buf: BufferWriter): void { + this._writeExtensions(buf, Array.from(this.extensions.values())) + } +} + +export class Finished extends HandshakeMessage { + verifyData: Uint8Array + + constructor(verifyData: Uint8Array) { + super() + this.verifyData = verifyData + } + + get TYPE_TAG(): number { + return HANDSHAKE_TYPE.FINISHED + } + + static _read(buf: BufferReader): Finished { + const verifyData = buf.readBytes(HASH_LENGTH) + return new this(verifyData) + } + + _write(buf: BufferWriter): void { + buf.writeBytes(this.verifyData) + } +} + +export class NewSessionTicket extends HandshakeMessage { + ticketLifetime: number + ticketAgeAdd: number + ticketNonce: Uint8Array + ticket: Uint8Array + extensions: ExtensionMap + + constructor( + ticketLifetime: number, + ticketAgeAdd: number, + ticketNonce: Uint8Array, + ticket: Uint8Array, + extensions: ExtensionMap + ) { + super() + this.ticketLifetime = ticketLifetime + this.ticketAgeAdd = ticketAgeAdd + this.ticketNonce = ticketNonce + this.ticket = ticket + this.extensions = extensions + } + + get TYPE_TAG(): number { + return HANDSHAKE_TYPE.NEW_SESSION_TICKET + } + + static _read(buf: BufferReader): NewSessionTicket { + const ticketLifetime = buf.readUint32() + const ticketAgeAdd = buf.readUint32() + const ticketNonce = buf.readVectorBytes8() + const ticket = buf.readVectorBytes16() + if (ticket.byteLength < 1) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + const extensions = this._readExtensions( + HANDSHAKE_TYPE.NEW_SESSION_TICKET, + buf + ) + return new this( + ticketLifetime, + ticketAgeAdd, + ticketNonce, + ticket, + extensions + ) + } + + _write(buf: BufferWriter): void { + buf.writeUint32(this.ticketLifetime) + buf.writeUint32(this.ticketAgeAdd) + buf.writeVectorBytes8(this.ticketNonce) + buf.writeVectorBytes16(this.ticket) + this._writeExtensions(buf, Array.from(this.extensions.values())) + } +} diff --git a/frontend/src/lib/pairing-channel/recordlayer.ts b/frontend/src/lib/pairing-channel/recordlayer.ts new file mode 100644 index 00000000..051b820a --- /dev/null +++ b/frontend/src/lib/pairing-channel/recordlayer.ts @@ -0,0 +1,274 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { VERSION_TLS_1_2, VERSION_TLS_1_0 } from "./constants" +import { BufferReader, BufferWriter, EMPTY } from "./utils" +import { ALERT_DESCRIPTION, TLSError } from "./alerts" +import { + encrypt, + decrypt, + prepareKey, + hkdfExpandLabel, + AEAD_SIZE_INFLATION, + IV_LENGTH, + KEY_LENGTH, +} from "./crypto" + +export const RECORD_TYPE = { + CHANGE_CIPHER_SPEC: 20, + ALERT: 21, + HANDSHAKE: 22, + APPLICATION_DATA: 23, +} as const + +const MAX_SEQUENCE_NUMBER = Math.pow(2, 24) +const MAX_RECORD_SIZE = Math.pow(2, 14) +const MAX_ENCRYPTED_RECORD_SIZE = MAX_RECORD_SIZE + 256 +const RECORD_HEADER_SIZE = 5 + +export class CipherState { + key: CryptoKey + iv: Uint8Array + seqnum: number + + constructor(key: CryptoKey, iv: Uint8Array) { + this.key = key + this.iv = iv + this.seqnum = 0 + } + + static async create( + baseKey: Uint8Array, + mode: "encrypt" | "decrypt" + ): Promise { + const key = await prepareKey( + await hkdfExpandLabel(baseKey, "key", EMPTY, KEY_LENGTH), + mode + ) + const iv = await hkdfExpandLabel(baseKey, "iv", EMPTY, IV_LENGTH) + return new this(key, iv) as CipherState + } + + nonce(): Uint8Array { + const nonce = this.iv.slice() + const dv = new DataView(nonce.buffer, nonce.byteLength - 4, 4) + dv.setUint32(0, dv.getUint32(0) ^ this.seqnum) + this.seqnum += 1 + if (this.seqnum > MAX_SEQUENCE_NUMBER) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + return nonce + } +} + +export class EncryptionState extends CipherState { + static async create(key: Uint8Array): Promise { + const cryptoKey = await prepareKey( + await hkdfExpandLabel(key, "key", EMPTY, KEY_LENGTH), + "encrypt" + ) + const iv = await hkdfExpandLabel(key, "iv", EMPTY, IV_LENGTH) + const state = new EncryptionState(cryptoKey, iv) + return state + } + + async encrypt( + plaintext: Uint8Array, + additionalData: Uint8Array + ): Promise { + return await encrypt(this.key, this.nonce(), plaintext, additionalData) + } +} + +export class DecryptionState extends CipherState { + static async create(key: Uint8Array): Promise { + const cryptoKey = await prepareKey( + await hkdfExpandLabel(key, "key", EMPTY, KEY_LENGTH), + "decrypt" + ) + const iv = await hkdfExpandLabel(key, "iv", EMPTY, IV_LENGTH) + const state = new DecryptionState(cryptoKey, iv) + return state + } + + async decrypt( + ciphertext: Uint8Array, + additionalData: Uint8Array + ): Promise { + return await decrypt(this.key, this.nonce(), ciphertext, additionalData) + } +} + +export class RecordLayer { + sendCallback: (data: Uint8Array) => void | Promise + _sendEncryptState: EncryptionState | null + _sendError: Error | null + _recvDecryptState: DecryptionState | null + _recvError: Error | null + _pendingRecordType: number + _pendingRecordBuf: BufferWriter | null + + constructor(sendCallback: (data: Uint8Array) => void | Promise) { + this.sendCallback = sendCallback + this._sendEncryptState = null + this._sendError = null + this._recvDecryptState = null + this._recvError = null + this._pendingRecordType = 0 + this._pendingRecordBuf = null + } + + async setSendKey(key: Uint8Array): Promise { + await this.flush() + this._sendEncryptState = await EncryptionState.create(key) + } + + async setRecvKey(key: Uint8Array): Promise { + this._recvDecryptState = await DecryptionState.create(key) + } + + async setSendError(err: Error): Promise { + this._sendError = err + } + + async setRecvError(err: Error): Promise { + this._recvError = err + } + + async send(type: number, data: Uint8Array): Promise { + if (this._sendError !== null) { + throw this._sendError + } + if (data.byteLength > MAX_RECORD_SIZE) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + if (this._pendingRecordType && this._pendingRecordType !== type) { + await this.flush() + } + if (this._pendingRecordBuf !== null) { + if (this._pendingRecordBuf.tell() + data.byteLength > MAX_RECORD_SIZE) { + await this.flush() + } + } + if (this._pendingRecordBuf === null) { + this._pendingRecordType = type + this._pendingRecordBuf = new BufferWriter() + this._pendingRecordBuf.incr(RECORD_HEADER_SIZE) + } + this._pendingRecordBuf.writeBytes(data) + } + + async flush(): Promise { + const buf = this._pendingRecordBuf + let type = this._pendingRecordType + if (!type) { + if (buf !== null) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + return + } + if (this._sendError !== null) { + throw this._sendError + } + let inflation = 0, + innerPlaintext: Uint8Array | null = null + if (this._sendEncryptState !== null) { + buf!.writeUint8(type) + innerPlaintext = buf!.slice(RECORD_HEADER_SIZE) + inflation = AEAD_SIZE_INFLATION + type = RECORD_TYPE.APPLICATION_DATA + } + const length = buf!.tell() - RECORD_HEADER_SIZE + inflation + buf!.seek(0) + buf!.writeUint8(type) + buf!.writeUint16(VERSION_TLS_1_2) + buf!.writeUint16(length) + if (this._sendEncryptState !== null) { + const additionalData = buf!.slice(0, RECORD_HEADER_SIZE) + const ciphertext = await this._sendEncryptState.encrypt( + innerPlaintext!, + additionalData + ) + buf!.writeBytes(ciphertext) + } else { + buf!.incr(length) + } + this._pendingRecordBuf = null + this._pendingRecordType = 0 + await this.sendCallback(buf!.flush()) + } + + async recv(data: Uint8Array): Promise<[number, Uint8Array]> { + if (this._recvError !== null) { + throw this._recvError + } + const buf = new BufferReader(data) + let type = buf.readUint8() + const version = buf.readUint16() + if (version !== VERSION_TLS_1_2) { + if (this._recvDecryptState !== null || version !== VERSION_TLS_1_0) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + } + const length = buf.readUint16() + let result: [number, Uint8Array] + if ( + this._recvDecryptState === null || + type === RECORD_TYPE.CHANGE_CIPHER_SPEC + ) { + result = await this._readPlaintextRecord(type, length, buf) + } else { + result = await this._readEncryptedRecord(type, length, buf) + } + if (buf.hasMoreBytes()) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + return result + } + + async _readPlaintextRecord( + type: number, + length: number, + buf: BufferReader + ): Promise<[number, Uint8Array]> { + if (length > MAX_RECORD_SIZE) { + throw new TLSError(ALERT_DESCRIPTION.RECORD_OVERFLOW) + } + return [type, buf.readBytes(length)] + } + + async _readEncryptedRecord( + type: number, + length: number, + buf: BufferReader + ): Promise<[number, Uint8Array]> { + if (length > MAX_ENCRYPTED_RECORD_SIZE) { + throw new TLSError(ALERT_DESCRIPTION.RECORD_OVERFLOW) + } + if (type !== RECORD_TYPE.APPLICATION_DATA) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + buf.incr(-RECORD_HEADER_SIZE) + const additionalData = buf.readBytes(RECORD_HEADER_SIZE) + const ciphertext = buf.readBytes(length) + const paddedPlaintext = await this._recvDecryptState!.decrypt( + ciphertext, + additionalData + ) + let i: number + for (i = paddedPlaintext.byteLength - 1; i >= 0; i--) { + if (paddedPlaintext[i] !== 0) { + break + } + } + if (i < 0) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + type = paddedPlaintext[i] + if (type === RECORD_TYPE.CHANGE_CIPHER_SPEC) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + return [type, paddedPlaintext.slice(0, i)] + } +} diff --git a/frontend/src/lib/pairing-channel/states.ts b/frontend/src/lib/pairing-channel/states.ts new file mode 100644 index 00000000..d257d23f --- /dev/null +++ b/frontend/src/lib/pairing-channel/states.ts @@ -0,0 +1,391 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { bytesAreEqual, BufferWriter, zeros } from "./utils" +import { getRandomBytes, HASH_LENGTH } from "./crypto" +import { TLSAlert, TLSCloseNotify, TLSError, ALERT_DESCRIPTION } from "./alerts" +import { + ClientHello, + ServerHello, + EncryptedExtensions, + Finished, + NewSessionTicket, + type HandshakeMessage, +} from "./messages" +import { + SupportedVersionsExtension, + PskKeyExchangeModesExtension, + PreSharedKeyExtension, + EXTENSION_TYPE, +} from "./extensions" +import { VERSION_TLS_1_3, PSK_MODE_KE } from "./constants" +import type { Connection } from "./tlsconnection" + +export class State { + conn: Connection + + constructor(conn: Connection) { + this.conn = conn + } + + async initialize(..._args: unknown[]): Promise { + // By default, nothing to do when entering the state. + } + + async sendApplicationData(_bytes: Uint8Array): Promise { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + + async recvApplicationData(_bytes: Uint8Array): Promise { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + + async recvHandshakeMessage(_msg: HandshakeMessage): Promise { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + + async recvAlertMessage(alert: TLSAlert): Promise { + switch (alert.description) { + case ALERT_DESCRIPTION.CLOSE_NOTIFY: + this.conn._closeForRecv(alert) + throw alert + default: + return await this.handleErrorAndRethrow(alert) + } + } + + async recvChangeCipherSpec(_bytes: Uint8Array): Promise { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + + async handleErrorAndRethrow(err: Error): Promise { + let alert: TLSAlert = err as TLSAlert + if (!(alert instanceof TLSAlert)) { + alert = new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + try { + await this.conn._sendAlertMessage(alert) + } catch { + // ignore + } + await this.conn._transition(ERROR, err) + throw err + } + + async close(): Promise { + const alert = new TLSCloseNotify() + await this.conn._sendAlertMessage(alert) + this.conn._closeForSend(alert) + } +} + +export class UNINITIALIZED extends State { + async initialize(): Promise { + throw new Error("uninitialized state") + } + async sendApplicationData(_bytes: Uint8Array): Promise { + throw new Error("uninitialized state") + } + async recvApplicationData(_bytes: Uint8Array): Promise { + throw new Error("uninitialized state") + } + async recvHandshakeMessage(_msg: HandshakeMessage): Promise { + throw new Error("uninitialized state") + } + async recvChangeCipherSpec(_bytes: Uint8Array): Promise { + throw new Error("uninitialized state") + } + async handleErrorAndRethrow(err: Error): Promise { + throw err + } + async close(): Promise { + throw new Error("uninitialized state") + } +} + +export class ERROR extends State { + error!: Error + + async initialize(err: Error): Promise { + this.error = err + this.conn._setConnectionFailure(err) + this.conn._recordlayer.setSendError(err) + this.conn._recordlayer.setRecvError(err) + } + async sendApplicationData(_bytes: Uint8Array): Promise { + throw this.error + } + async recvApplicationData(_bytes: Uint8Array): Promise { + throw this.error + } + async recvHandshakeMessage(_msg: HandshakeMessage): Promise { + throw this.error + } + async recvAlertMessage(_err: TLSAlert): Promise { + throw this.error + } + async recvChangeCipherSpec(_bytes: Uint8Array): Promise { + throw this.error + } + async handleErrorAndRethrow(err: Error): Promise { + throw err + } + async close(): Promise { + throw this.error + } +} + +export class CONNECTED extends State { + async initialize(): Promise { + this.conn._setConnectionSuccess() + } + async sendApplicationData(bytes: Uint8Array): Promise { + await this.conn._sendApplicationData(bytes) + } + async recvApplicationData(bytes: Uint8Array): Promise { + return bytes + } + async recvChangeCipherSpec(_bytes: Uint8Array): Promise { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } +} + +class MidHandshakeState extends State { + async recvChangeCipherSpec(bytes: Uint8Array): Promise { + if (this.conn._hasSeenChangeCipherSpec) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + if (bytes.byteLength !== 1 || bytes[0] !== 1) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + this.conn._hasSeenChangeCipherSpec = true + } +} + +export class CLIENT_START extends State { + async initialize(): Promise { + const keyschedule = this.conn._keyschedule + await keyschedule.addPSK(this.conn.psk) + const clientHello = new ClientHello( + await getRandomBytes(32), + await getRandomBytes(32), + [ + new SupportedVersionsExtension([VERSION_TLS_1_3]), + new PskKeyExchangeModesExtension([PSK_MODE_KE]), + new PreSharedKeyExtension([this.conn.pskId], [zeros(HASH_LENGTH)], null), + ] + ) + const buf = new BufferWriter() + clientHello.write(buf) + const PSK_BINDERS_SIZE = HASH_LENGTH + 1 + 2 + const truncatedTranscript = buf.slice(0, buf.tell() - PSK_BINDERS_SIZE) + const pskBinder = await keyschedule.calculateFinishedMAC( + keyschedule.extBinderKey!, + truncatedTranscript + ) + buf.incr(-HASH_LENGTH) + buf.writeBytes(pskBinder) + await this.conn._sendHandshakeMessageBytes(buf.flush()) + await this.conn._transition(CLIENT_WAIT_SH, clientHello.sessionId) + } +} + +class CLIENT_WAIT_SH extends State { + _sessionId!: Uint8Array + + async initialize(sessionId: Uint8Array): Promise { + this._sessionId = sessionId + } + async recvHandshakeMessage(msg: HandshakeMessage): Promise { + if (!(msg instanceof ServerHello)) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + if (!bytesAreEqual(msg.sessionId, this._sessionId)) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + const pskExt = msg.extensions.get(EXTENSION_TYPE.PRE_SHARED_KEY) as + | { selectedIdentity: number } + | undefined + if (!pskExt) { + throw new TLSError(ALERT_DESCRIPTION.MISSING_EXTENSION) + } + if (msg.extensions.size !== 2) { + throw new TLSError(ALERT_DESCRIPTION.UNSUPPORTED_EXTENSION) + } + if (pskExt.selectedIdentity !== 0) { + throw new TLSError(ALERT_DESCRIPTION.ILLEGAL_PARAMETER) + } + await this.conn._keyschedule.addECDHE(null) + // If we sent a non-empty sessionId, send a CCS for backward compatibility + // before switching to encrypted keys. + if (this._sessionId.byteLength > 0) { + await this.conn._sendChangeCipherSpec() + } + await this.conn._setSendKey( + this.conn._keyschedule.clientHandshakeTrafficSecret! + ) + await this.conn._setRecvKey( + this.conn._keyschedule.serverHandshakeTrafficSecret! + ) + await this.conn._transition(CLIENT_WAIT_EE) + } +} + +class CLIENT_WAIT_EE extends MidHandshakeState { + async recvHandshakeMessage(msg: HandshakeMessage): Promise { + if (!(msg instanceof EncryptedExtensions)) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + if (msg.extensions.size !== 0) { + throw new TLSError(ALERT_DESCRIPTION.UNSUPPORTED_EXTENSION) + } + const keyschedule = this.conn._keyschedule + const serverFinishedTranscript = keyschedule.getTranscript() + await this.conn._transition( + CLIENT_WAIT_FINISHED, + serverFinishedTranscript + ) + } +} + +class CLIENT_WAIT_FINISHED extends State { + _serverFinishedTranscript!: Uint8Array + + async initialize(serverFinishedTranscript: Uint8Array): Promise { + this._serverFinishedTranscript = serverFinishedTranscript + } + async recvHandshakeMessage(msg: HandshakeMessage): Promise { + if (!(msg instanceof Finished)) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + const keyschedule = this.conn._keyschedule + await keyschedule.verifyFinishedMAC( + keyschedule.serverHandshakeTrafficSecret!, + msg.verifyData, + this._serverFinishedTranscript + ) + const clientFinishedMAC = await keyschedule.calculateFinishedMAC( + keyschedule.clientHandshakeTrafficSecret! + ) + await keyschedule.finalize() + await this.conn._sendHandshakeMessage(new Finished(clientFinishedMAC)) + await this.conn._setSendKey( + keyschedule.clientApplicationTrafficSecret! + ) + await this.conn._setRecvKey( + keyschedule.serverApplicationTrafficSecret! + ) + await this.conn._transition(CLIENT_CONNECTED) + } +} + +export class CLIENT_CONNECTED extends CONNECTED { + async recvHandshakeMessage(msg: HandshakeMessage): Promise { + if (!(msg instanceof NewSessionTicket)) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + } +} + +export class SERVER_START extends State { + async recvHandshakeMessage(msg: HandshakeMessage): Promise { + if (!(msg instanceof ClientHello)) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + const pskExt = msg.extensions.get(EXTENSION_TYPE.PRE_SHARED_KEY) as + | { identities: Uint8Array[]; binders: Uint8Array[] } + | undefined + const pskModesExt = msg.extensions.get( + EXTENSION_TYPE.PSK_KEY_EXCHANGE_MODES + ) as { modes: number[] } | undefined + if (!pskExt || !pskModesExt) { + throw new TLSError(ALERT_DESCRIPTION.MISSING_EXTENSION) + } + if (pskModesExt.modes.indexOf(PSK_MODE_KE) === -1) { + throw new TLSError(ALERT_DESCRIPTION.HANDSHAKE_FAILURE) + } + const pskIndex = pskExt.identities.findIndex((pskId) => + bytesAreEqual(pskId, this.conn.pskId) + ) + if (pskIndex === -1) { + throw new TLSError(ALERT_DESCRIPTION.UNKNOWN_PSK_IDENTITY) + } + await this.conn._keyschedule.addPSK(this.conn.psk) + const keyschedule = this.conn._keyschedule + const transcript = keyschedule.getTranscript() + let pskBindersSize = 2 + for (const binder of pskExt.binders) { + pskBindersSize += binder.byteLength + 1 + } + await keyschedule.verifyFinishedMAC( + keyschedule.extBinderKey!, + pskExt.binders[pskIndex], + transcript.slice(0, -pskBindersSize) + ) + await this.conn._transition(SERVER_NEGOTIATED, msg.sessionId, pskIndex) + } +} + +class SERVER_NEGOTIATED extends MidHandshakeState { + async initialize( + sessionId: Uint8Array, + pskIndex: number + ): Promise { + await this.conn._sendHandshakeMessage( + new ServerHello(await getRandomBytes(32), sessionId, [ + new SupportedVersionsExtension(null, VERSION_TLS_1_3), + new PreSharedKeyExtension(null, null, pskIndex), + ]) + ) + if (sessionId.byteLength > 0) { + await this.conn._sendChangeCipherSpec() + } + const keyschedule = this.conn._keyschedule + await keyschedule.addECDHE(null) + await this.conn._setSendKey(keyschedule.serverHandshakeTrafficSecret!) + await this.conn._setRecvKey(keyschedule.clientHandshakeTrafficSecret!) + await this.conn._sendHandshakeMessage(new EncryptedExtensions([])) + const serverFinishedMAC = await keyschedule.calculateFinishedMAC( + keyschedule.serverHandshakeTrafficSecret! + ) + await this.conn._sendHandshakeMessage(new Finished(serverFinishedMAC)) + const clientFinishedTranscript = keyschedule.getTranscript() + const clientHandshakeTrafficSecret = + keyschedule.clientHandshakeTrafficSecret! + await keyschedule.finalize() + await this.conn._setSendKey(keyschedule.serverApplicationTrafficSecret!) + await this.conn._transition( + SERVER_WAIT_FINISHED, + clientHandshakeTrafficSecret, + clientFinishedTranscript + ) + } +} + +class SERVER_WAIT_FINISHED extends MidHandshakeState { + _clientHandshakeTrafficSecret!: Uint8Array | null + _clientFinishedTranscript!: Uint8Array | null + + async initialize( + clientHandshakeTrafficSecret: Uint8Array, + clientFinishedTranscript: Uint8Array + ): Promise { + this._clientHandshakeTrafficSecret = clientHandshakeTrafficSecret + this._clientFinishedTranscript = clientFinishedTranscript + } + async recvHandshakeMessage(msg: HandshakeMessage): Promise { + if (!(msg instanceof Finished)) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + const keyschedule = this.conn._keyschedule + await keyschedule.verifyFinishedMAC( + this._clientHandshakeTrafficSecret!, + msg.verifyData, + this._clientFinishedTranscript! + ) + this._clientHandshakeTrafficSecret = this._clientFinishedTranscript = null + await this.conn._setRecvKey(keyschedule.clientApplicationTrafficSecret!) + await this.conn._transition(CONNECTED) + } +} diff --git a/frontend/src/lib/pairing-channel/tlsconnection.ts b/frontend/src/lib/pairing-channel/tlsconnection.ts new file mode 100644 index 00000000..399909fd --- /dev/null +++ b/frontend/src/lib/pairing-channel/tlsconnection.ts @@ -0,0 +1,226 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import * as STATE from "./states" +import { assertIsBytes, noop, BufferReader } from "./utils" +import { HandshakeMessage } from "./messages" +import { KeySchedule } from "./keyschedule" +import { RecordLayer, RECORD_TYPE } from "./recordlayer" +import { TLSAlert, TLSError, ALERT_DESCRIPTION, TLSCloseNotify } from "./alerts" + +type StateClass = new (conn: Connection) => STATE.State + +export class Connection { + psk: Uint8Array + pskId: Uint8Array + connected: Promise + _onConnectionSuccess!: (() => void) | null + _onConnectionFailure!: ((err: Error) => void) | null + _state: STATE.State + _handshakeRecvBuffer: BufferReader | null + _hasSeenChangeCipherSpec: boolean + _recordlayer: RecordLayer + _keyschedule: KeySchedule + _lastPromise: Promise + + constructor( + psk: Uint8Array, + pskId: Uint8Array, + sendCallback: (data: Uint8Array) => void | Promise + ) { + this.psk = assertIsBytes(psk) + this.pskId = assertIsBytes(pskId) + this.connected = new Promise((resolve, reject) => { + this._onConnectionSuccess = resolve + this._onConnectionFailure = reject + }) + this._state = new STATE.UNINITIALIZED(this) + this._handshakeRecvBuffer = null + this._hasSeenChangeCipherSpec = false + this._recordlayer = new RecordLayer(sendCallback) + this._keyschedule = new KeySchedule() + this._lastPromise = Promise.resolve() + } + + static async create( + psk: Uint8Array, + pskId: Uint8Array, + sendCallback: (data: Uint8Array) => void | Promise + ): Promise { + return new this(psk, pskId, sendCallback) + } + + async send(data: Uint8Array): Promise { + assertIsBytes(data) + await this.connected + await this._synchronized(async () => { + await this._state.sendApplicationData(data) + }) + } + + async recv(data: Uint8Array): Promise { + assertIsBytes(data) + return await this._synchronized(async () => { + const [type, bytes] = await this._recordlayer.recv(data) + switch (type) { + case RECORD_TYPE.CHANGE_CIPHER_SPEC: + await this._state.recvChangeCipherSpec(bytes) + return null + case RECORD_TYPE.ALERT: + await this._state.recvAlertMessage(TLSAlert.fromBytes(bytes)) + return null + case RECORD_TYPE.APPLICATION_DATA: + return await this._state.recvApplicationData(bytes) + case RECORD_TYPE.HANDSHAKE: + this._handshakeRecvBuffer = new BufferReader(bytes) + if (!this._handshakeRecvBuffer.hasMoreBytes()) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + do { + this._handshakeRecvBuffer.incr(1) + const mlength = this._handshakeRecvBuffer.readUint24() + this._handshakeRecvBuffer.incr(-4) + const messageBytes = this._handshakeRecvBuffer.readBytes( + mlength + 4 + ) + this._keyschedule.addToTranscript(messageBytes) + await this._state.recvHandshakeMessage( + HandshakeMessage.fromBytes(messageBytes) + ) + } while (this._handshakeRecvBuffer.hasMoreBytes()) + this._handshakeRecvBuffer = null + return null + default: + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + }) + } + + async close(): Promise { + await this._synchronized(async () => { + await this._state.close() + }) + } + + _synchronized(cb: () => Promise): Promise { + const nextPromise = this._lastPromise + .then(() => { + return cb() + }) + .catch(async (err: Error) => { + if (err instanceof TLSCloseNotify) { + throw err + } + await this._state.handleErrorAndRethrow(err) + }) as Promise + this._lastPromise = nextPromise.then(noop, noop) + return nextPromise + } + + async _transition( + StateConstructor: StateClass, + ...args: unknown[] + ): Promise { + this._state = new StateConstructor(this) + await this._state.initialize(...args) + await this._recordlayer.flush() + } + + async _sendApplicationData(bytes: Uint8Array): Promise { + await this._recordlayer.send(RECORD_TYPE.APPLICATION_DATA, bytes) + await this._recordlayer.flush() + } + + async _sendHandshakeMessage(msg: HandshakeMessage): Promise { + await this._sendHandshakeMessageBytes(msg.toBytes()) + } + + async _sendHandshakeMessageBytes(bytes: Uint8Array): Promise { + this._keyschedule.addToTranscript(bytes) + await this._recordlayer.send(RECORD_TYPE.HANDSHAKE, bytes) + } + + async _sendAlertMessage(err: TLSAlert): Promise { + await this._recordlayer.send(RECORD_TYPE.ALERT, err.toBytes()) + await this._recordlayer.flush() + } + + async _sendChangeCipherSpec(): Promise { + await this._recordlayer.send( + RECORD_TYPE.CHANGE_CIPHER_SPEC, + new Uint8Array([0x01]) + ) + await this._recordlayer.flush() + } + + async _setSendKey(key: Uint8Array): Promise { + return await this._recordlayer.setSendKey(key) + } + + async _setRecvKey(key: Uint8Array): Promise { + if ( + this._handshakeRecvBuffer && + this._handshakeRecvBuffer.hasMoreBytes() + ) { + throw new TLSError(ALERT_DESCRIPTION.UNEXPECTED_MESSAGE) + } + return await this._recordlayer.setRecvKey(key) + } + + _setConnectionSuccess(): void { + if (this._onConnectionSuccess !== null) { + this._onConnectionSuccess() + this._onConnectionSuccess = null + this._onConnectionFailure = null + } + } + + _setConnectionFailure(err: Error): void { + if (this._onConnectionFailure !== null) { + this._onConnectionFailure(err) + this._onConnectionSuccess = null + this._onConnectionFailure = null + } + } + + _closeForSend(alert: TLSAlert): void { + this._recordlayer.setSendError(alert) + } + + _closeForRecv(alert: TLSAlert): void { + this._recordlayer.setRecvError(alert) + } +} + +export class ClientConnection extends Connection { + static async create( + psk: Uint8Array, + pskId: Uint8Array, + sendCallback: (data: Uint8Array) => void | Promise + ): Promise { + const instance = (await super.create( + psk, + pskId, + sendCallback + )) as ClientConnection + await instance._transition(STATE.CLIENT_START) + return instance + } +} + +export class ServerConnection extends Connection { + static async create( + psk: Uint8Array, + pskId: Uint8Array, + sendCallback: (data: Uint8Array) => void | Promise + ): Promise { + const instance = (await super.create( + psk, + pskId, + sendCallback + )) as ServerConnection + await instance._transition(STATE.SERVER_START) + return instance + } +} diff --git a/frontend/src/lib/pairing-channel/utils.ts b/frontend/src/lib/pairing-channel/utils.ts new file mode 100644 index 00000000..7b9e949d --- /dev/null +++ b/frontend/src/lib/pairing-channel/utils.ts @@ -0,0 +1,375 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { ALERT_DESCRIPTION, TLSError } from "./alerts" + +// +// Various low-level utility functions. +// +// These are mostly conveniences for working with Uint8Arrays as +// the primitive "bytes" type. +// + +const UTF8_ENCODER = new TextEncoder() +const UTF8_DECODER = new TextDecoder() + +export function noop(): void {} + +export function assert(cond: unknown, msg: string): asserts cond { + if (!cond) { + throw new Error("assert failed: " + msg) + } +} + +export function assertIsBytes( + value: unknown, + msg = "value must be a Uint8Array" +): Uint8Array { + // Using `value instanceof Uint8Array` seems to fail in Firefox chrome code + // for inscrutable reasons, so we do a less direct check. + assert(ArrayBuffer.isView(value), msg) + assert((value as unknown as { BYTES_PER_ELEMENT: number }).BYTES_PER_ELEMENT === 1, msg) + return value as Uint8Array +} + +export const EMPTY: Uint8Array = new Uint8Array(0) + +export function zeros(n: number): Uint8Array { + return new Uint8Array(n) +} + +export function arrayToBytes(value: number[]): Uint8Array { + return new Uint8Array(value) +} + +export function bytesToHex(bytes: Uint8Array): string { + return Array.prototype.map + .call(bytes, (byte: number) => { + let s = byte.toString(16) + if (s.length === 1) { + s = "0" + s + } + return s + }) + .join("") +} + +export function hexToBytes(hexstr: string): Uint8Array { + assert(hexstr.length % 2 === 0, "hexstr.length must be even") + const pairs: string[] = [] + for (let i = 0; i < hexstr.length; i += 2) { + pairs.push(hexstr[i] + hexstr[i + 1]) + } + return new Uint8Array(pairs.map((s) => parseInt(s, 16))) +} + +export function bytesToUtf8(bytes: Uint8Array): string { + return UTF8_DECODER.decode(bytes) +} + +export function utf8ToBytes(str: string): Uint8Array { + return UTF8_ENCODER.encode(str) +} + +export function bytesToBase64url(bytes: Uint8Array): string { + const charCodes = String.fromCharCode.apply(String, bytes as unknown as number[]) + return btoa(charCodes).replace(/\+/g, "-").replace(/\//g, "_") +} + +export function base64urlToBytes(str: string): Uint8Array { + str = atob(str.replace(/-/g, "+").replace(/_/g, "/")) + const bytes = new Uint8Array(str.length) + for (let i = 0; i < str.length; i++) { + bytes[i] = str.charCodeAt(i) + } + return bytes +} + +export function bytesAreEqual(v1: Uint8Array, v2: Uint8Array): boolean { + assertIsBytes(v1) + assertIsBytes(v2) + if (v1.length !== v2.length) { + return false + } + for (let i = 0; i < v1.length; i++) { + if (v1[i] !== v2[i]) { + return false + } + } + return true +} + +// The `BufferReader` and `BufferWriter` classes are helpers for dealing with the +// binary struct format that's used for various TLS message. + +class BufferWithPointer { + _buffer: Uint8Array + _dataview: DataView + _pos: number + + constructor(buf: Uint8Array) { + this._buffer = buf + this._dataview = new DataView(buf.buffer, buf.byteOffset, buf.byteLength) + this._pos = 0 + } + + length(): number { + return this._buffer.byteLength + } + + tell(): number { + return this._pos + } + + seek(pos: number): void { + if (pos < 0) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + if (pos > this.length()) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + this._pos = pos + } + + incr(offset: number): void { + this.seek(this._pos + offset) + } +} + +export class BufferReader extends BufferWithPointer { + hasMoreBytes(): boolean { + return this.tell() < this.length() + } + + readBytes(length: number): Uint8Array { + const start = this._buffer.byteOffset + this.tell() + this.incr(length) + return new Uint8Array(this._buffer.buffer, start, length) + } + + _rangeErrorToAlert(cb: (self: this) => T): T { + try { + return cb(this) + } catch (err) { + if (err instanceof RangeError) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + throw err + } + } + + readUint8(): number { + return this._rangeErrorToAlert(() => { + const n = this._dataview.getUint8(this._pos) + this.incr(1) + return n + }) + } + + readUint16(): number { + return this._rangeErrorToAlert(() => { + const n = this._dataview.getUint16(this._pos) + this.incr(2) + return n + }) + } + + readUint24(): number { + return this._rangeErrorToAlert(() => { + let n = this._dataview.getUint16(this._pos) + n = (n << 8) | this._dataview.getUint8(this._pos + 2) + this.incr(3) + return n + }) + } + + readUint32(): number { + return this._rangeErrorToAlert(() => { + const n = this._dataview.getUint32(this._pos) + this.incr(4) + return n + }) + } + + _readVector(length: number, cb: (buf: BufferReader, n: number) => void): void { + const contentsBuf = new BufferReader(this.readBytes(length)) + const expectedEnd = this.tell() + let n = 0 + while (contentsBuf.hasMoreBytes()) { + const prevPos = contentsBuf.tell() + cb(contentsBuf, n) + if (contentsBuf.tell() <= prevPos) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + n += 1 + } + if (this.tell() !== expectedEnd) { + throw new TLSError(ALERT_DESCRIPTION.DECODE_ERROR) + } + } + + readVector8(cb: (buf: BufferReader, n: number) => void): void { + const length = this.readUint8() + return this._readVector(length, cb) + } + + readVector16(cb: (buf: BufferReader, n: number) => void): void { + const length = this.readUint16() + return this._readVector(length, cb) + } + + readVector24(cb: (buf: BufferReader, n: number) => void): void { + const length = this.readUint24() + return this._readVector(length, cb) + } + + readVectorBytes8(): Uint8Array { + return this.readBytes(this.readUint8()) + } + + readVectorBytes16(): Uint8Array { + return this.readBytes(this.readUint16()) + } + + readVectorBytes24(): Uint8Array { + return this.readBytes(this.readUint24()) + } +} + +export class BufferWriter extends BufferWithPointer { + constructor(size = 1024) { + super(new Uint8Array(size)) + } + + _maybeGrow(n: number): void { + const curSize = this._buffer.byteLength + const newPos = this._pos + n + const shortfall = newPos - curSize + if (shortfall > 0) { + const incr = Math.min(curSize, 4 * 1024) + const newbuf = new Uint8Array(curSize + Math.ceil(shortfall / incr) * incr) + newbuf.set(this._buffer, 0) + this._buffer = newbuf + this._dataview = new DataView( + newbuf.buffer, + newbuf.byteOffset, + newbuf.byteLength + ) + } + } + + slice(start = 0, end = this.tell()): Uint8Array { + if (end < 0) { + end = this.tell() + end + } + if (start < 0) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + if (end < 0) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + if (end > this.length()) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + return this._buffer.slice(start, end) + } + + flush(): Uint8Array { + const slice = this.slice() + this.seek(0) + return slice + } + + writeBytes(data: Uint8Array): void { + this._maybeGrow(data.byteLength) + this._buffer.set(data, this.tell()) + this.incr(data.byteLength) + } + + writeUint8(n: number): void { + this._maybeGrow(1) + this._dataview.setUint8(this._pos, n) + this.incr(1) + } + + writeUint16(n: number): void { + this._maybeGrow(2) + this._dataview.setUint16(this._pos, n) + this.incr(2) + } + + writeUint24(n: number): void { + this._maybeGrow(3) + this._dataview.setUint16(this._pos, n >> 8) + this._dataview.setUint8(this._pos + 2, n & 0xff) + this.incr(3) + } + + writeUint32(n: number): void { + this._maybeGrow(4) + this._dataview.setUint32(this._pos, n) + this.incr(4) + } + + _writeVector( + maxLength: number, + writeLength: (len: number) => void, + cb: (buf: BufferWriter) => void + ): number { + const lengthPos = this.tell() + writeLength(0) + const bodyPos = this.tell() + cb(this) + const length = this.tell() - bodyPos + if (length >= maxLength) { + throw new TLSError(ALERT_DESCRIPTION.INTERNAL_ERROR) + } + this.seek(lengthPos) + writeLength(length) + this.incr(length) + return length + } + + writeVector8(cb: (buf: BufferWriter) => void): number { + return this._writeVector( + Math.pow(2, 8), + (len) => this.writeUint8(len), + cb + ) + } + + writeVector16(cb: (buf: BufferWriter) => void): number { + return this._writeVector( + Math.pow(2, 16), + (len) => this.writeUint16(len), + cb + ) + } + + writeVector24(cb: (buf: BufferWriter) => void): number { + return this._writeVector( + Math.pow(2, 24), + (len) => this.writeUint24(len), + cb + ) + } + + writeVectorBytes8(bytes: Uint8Array): number { + return this.writeVector8((buf) => { + buf.writeBytes(bytes) + }) + } + + writeVectorBytes16(bytes: Uint8Array): number { + return this.writeVector16((buf) => { + buf.writeBytes(bytes) + }) + } + + writeVectorBytes24(bytes: Uint8Array): number { + return this.writeVector24((buf) => { + buf.writeBytes(bytes) + }) + } +} diff --git a/frontend/src/lib/pairing.ts b/frontend/src/lib/pairing.ts new file mode 100644 index 00000000..31ef44de --- /dev/null +++ b/frontend/src/lib/pairing.ts @@ -0,0 +1,43 @@ +export function channelKeyToBase64url(key: Uint8Array): string { + let binary = "" + for (const byte of key) { + binary += String.fromCharCode(byte) + } + return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "") +} + +export function base64urlToChannelKey(b64: string): Uint8Array { + const str = atob(b64.replace(/-/g, "+").replace(/_/g, "/")) + const bytes = new Uint8Array(str.length) + for (let i = 0; i < str.length; i++) { + bytes[i] = str.charCodeAt(i) + } + return bytes +} + +export function buildPairUrl( + contentUrl: string, + channelId: string, + channelKey: Uint8Array +): string { + const keyB64 = channelKeyToBase64url(channelKey) + return `${contentUrl}/pair/supp#channel_id=${channelId}&channel_key=${keyB64}` +} + +export function parsePairFragment( + hash: string +): { channelId: string; channelKey: Uint8Array } | null { + const fragment = hash.startsWith("#") ? hash.slice(1) : hash + const params = new URLSearchParams(fragment) + const channelId = params.get("channel_id") + const channelKeyB64 = params.get("channel_key") + if (!channelId || !channelKeyB64) { + return null + } + try { + const channelKey = base64urlToChannelKey(channelKeyB64) + return { channelId, channelKey } + } catch { + return null + } +} diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts index feb5dabb..e96d226b 100644 --- a/frontend/src/lib/types.ts +++ b/frontend/src/lib/types.ts @@ -4,6 +4,7 @@ export interface AppConfig { redirectUri: string tokenServerUrl?: string authServerUrl?: string + pairingServerUrl?: string scopes: string[] } diff --git a/frontend/src/lib/webchannel.ts b/frontend/src/lib/webchannel.ts index 90184675..5186f412 100644 --- a/frontend/src/lib/webchannel.ts +++ b/frontend/src/lib/webchannel.ts @@ -102,3 +102,11 @@ export function sendFxAStatus( messageId ) } + +export function sendPairComplete(messageId?: string): void { + sendToFirefox("fxaccounts:pair_complete", {}, messageId) +} + +export function sendPairDecline(messageId?: string): void { + sendToFirefox("fxaccounts:pair_decline", {}, messageId) +} diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 391abcbd..9d45c341 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -1,3 +1,4 @@ +/// import path from "path" import tailwindcss from "@tailwindcss/vite" import react from "@vitejs/plugin-react" @@ -10,4 +11,7 @@ export default defineConfig({ "@": path.resolve(__dirname, "./src"), }, }, + test: { + environment: "happy-dom", + }, }) diff --git a/lambda/src/entrypoint/__init__.py b/lambda/src/entrypoint/__init__.py index 7ca2f808..486cd7df 100644 --- a/lambda/src/entrypoint/__init__.py +++ b/lambda/src/entrypoint/__init__.py @@ -1,4 +1,5 @@ from .auth_api import lambda_handler as auth_api_handler +from .channel_api import lambda_handler as channel_api_handler from .profile_api import lambda_handler as profile_api_handler from .storage_api import lambda_handler as storage_api_handler from .token_api import lambda_handler as token_api_handler diff --git a/lambda/src/entrypoint/channel_api.py b/lambda/src/entrypoint/channel_api.py new file mode 100644 index 00000000..eff6ddb5 --- /dev/null +++ b/lambda/src/entrypoint/channel_api.py @@ -0,0 +1,29 @@ +"""Channel API Lambda handler for WebSocket device pairing.""" + +from typing import Optional + +from aws_lambda_powertools.utilities.typing import LambdaContext + +from src.environment.service_provider import ServiceProvider + + +def lambda_handler( + event: dict, context: LambdaContext, service_provider: Optional[ServiceProvider] = None +) -> dict: + """ + Channel Service WebSocket Lambda handler. + + Handles WebSocket $connect, $disconnect, and $default routes + for device pairing channel relay. + + Args: + event: Lambda event from API Gateway WebSocket + context: Lambda context + service_provider: Optional ServiceProvider for dependency injection + + Returns: + WebSocket response dict + """ + if service_provider is None: # pragma: nocover + service_provider = ServiceProvider() + return service_provider.channel_service.handle(event, context) diff --git a/lambda/src/environment/service_provider.py b/lambda/src/environment/service_provider.py index dc716004..bc3dd470 100644 --- a/lambda/src/environment/service_provider.py +++ b/lambda/src/environment/service_provider.py @@ -45,6 +45,7 @@ WeaveTimestampMiddleware, ) from src.services.auth_account_manager import AuthAccountManager +from src.services.channel_service import ChannelService from src.services.fxa_token_manager import FxATokenManager from src.services.hawk_service import HawkService from src.services.jwt_service import JWTService @@ -409,3 +410,19 @@ def hawk_service(self) -> HawkService: timestamp_skew_tolerance=self.hawk_timestamp_skew_tolerance, token_duration=self.token_duration, ) + + # Channel Service properties + + @cached_property + def channel_table_name(self): + return os.environ.get("CHANNEL_TABLE_NAME") + + @cached_property + def channel_table(self): + """DynamoDB Table for pairing channel state""" + resource = self.session.resource("dynamodb") + return resource.Table(self.channel_table_name) + + @cached_property + def channel_service(self) -> ChannelService: + return ChannelService(table=self.channel_table, session=self.session) diff --git a/lambda/src/routes/auth/oauth_authorization.py b/lambda/src/routes/auth/oauth_authorization.py index 78ab81ee..26dce762 100644 --- a/lambda/src/routes/auth/oauth_authorization.py +++ b/lambda/src/routes/auth/oauth_authorization.py @@ -9,6 +9,13 @@ from src.services.oauth_code_manager import OAuthCodeManager from src.shared.base_route import BaseRoute +ALLOWED_REDIRECT_URIS = { + "urn:ietf:wg:oauth:2.0:oob", + "urn:ietf:wg:oauth:2.0:oob:pair-auth-webchannel", +} + +DEFAULT_REDIRECT_URI = "urn:ietf:wg:oauth:2.0:oob" + class OAuthAuthorizationRoute(BaseRoute): """Issue an OAuth authorization code authenticated with a session token.""" @@ -50,6 +57,10 @@ def handle(self, event) -> Response: if not state: return self._error(400, 107, "Missing state") + redirect_uri = body.get("redirect_uri", DEFAULT_REDIRECT_URI) + if redirect_uri not in ALLOWED_REDIRECT_URIS: + return self._error(400, 107, "Invalid redirect_uri") + code_challenge = body.get("code_challenge", "") code_challenge_method = body.get("code_challenge_method", "S256") keys_jwe = body.get("keys_jwe", "") @@ -70,7 +81,7 @@ def handle(self, event) -> Response: { "code": code, "state": state, - "redirect": "urn:ietf:wg:oauth:2.0:oob", + "redirect": redirect_uri, } ), ) diff --git a/lambda/src/services/channel_service.py b/lambda/src/services/channel_service.py new file mode 100644 index 00000000..5ec199ba --- /dev/null +++ b/lambda/src/services/channel_service.py @@ -0,0 +1,218 @@ +"""Channel Service — WebSocket message relay for device pairing.""" + +import json +import time +import uuid + +from botocore.exceptions import ClientError + +MAX_CONNECTIONS_PER_CHANNEL = 3 +MAX_MESSAGES_PER_CHANNEL = 10 +CHANNEL_TTL_SECONDS = 300 + + +class ChannelService: + """WebSocket channel relay for device pairing. + + Uses a PK-only DynamoDB table with TTL: + - CHANNEL#{channelId} — connections (list), messageCount, expiry + - CONN#{connectionId} — channelId, expiry + """ + + def __init__(self, table, session): + self._table = table + self._session = session + self._apigw_clients = {} + + def handle(self, event, context): + """Dispatch on WebSocket route key.""" + route_key = event["requestContext"]["routeKey"] + connection_id = event["requestContext"]["connectionId"] + + if route_key == "$connect": + return self._handle_connect(event, connection_id) + elif route_key == "$disconnect": + self._handle_disconnect(connection_id) + return {"statusCode": 200} + elif route_key == "$default": + return self._handle_message(event, connection_id) + else: + return {"statusCode": 400, "body": "Unknown route"} + + def _handle_connect(self, event, connection_id): + """Handle $connect — create or join a channel.""" + params = event.get("queryStringParameters") or {} + channel_id = params.get("channelId") + expiry = int(time.time()) + CHANNEL_TTL_SECONDS + + if channel_id: + return self._join_channel(channel_id, connection_id, expiry) + else: + return self._create_channel(event, connection_id, expiry) + + def _create_channel(self, event, connection_id, expiry): + """Create a new channel with this connection as the first member.""" + channel_id = str(uuid.uuid4()) + + # Put channel metadata + self._table.put_item( + Item={ + "PK": f"CHANNEL#{channel_id}", + "connections": [connection_id], + "messageCount": 0, + "expiry": expiry, + } + ) + + # Put reverse lookup + self._table.put_item( + Item={ + "PK": f"CONN#{connection_id}", + "channelId": channel_id, + "expiry": expiry, + } + ) + + # Notify creator of channel ID + self._post_to_connection( + event, + connection_id, + json.dumps({"channelId": channel_id}), + ) + + return {"statusCode": 200} + + def _join_channel(self, channel_id, connection_id, expiry): + """Join an existing channel atomically.""" + try: + self._table.update_item( + Key={"PK": f"CHANNEL#{channel_id}"}, + UpdateExpression="SET connections = list_append(connections, :conn)", + ConditionExpression="attribute_exists(PK) AND size(connections) < :max", + ExpressionAttributeValues={ + ":conn": [connection_id], + ":max": MAX_CONNECTIONS_PER_CHANNEL, + }, + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ConditionalCheckFailedException": + # Distinguish 404 (channel doesn't exist) vs 403 (channel full) + result = self._table.get_item(Key={"PK": f"CHANNEL#{channel_id}"}) + if "Item" not in result: + return {"statusCode": 404, "body": "Channel not found"} + else: + return {"statusCode": 403, "body": "Channel full"} + raise + + # Put reverse lookup + self._table.put_item( + Item={ + "PK": f"CONN#{connection_id}", + "channelId": channel_id, + "expiry": expiry, + } + ) + + return {"statusCode": 200} + + def _handle_disconnect(self, connection_id): + """Handle disconnect — remove connection from channel.""" + # Delete reverse lookup first (idempotent guard against double-disconnect) + result = self._table.get_item(Key={"PK": f"CONN#{connection_id}"}) + if "Item" not in result: + return + + channel_id = result["Item"]["channelId"] + self._table.delete_item(Key={"PK": f"CONN#{connection_id}"}) + + # Get channel to find connection index + channel_result = self._table.get_item(Key={"PK": f"CHANNEL#{channel_id}"}) + if "Item" not in channel_result: + return + + connections = channel_result["Item"]["connections"] + if connection_id in connections: + index = connections.index(connection_id) + self._table.update_item( + Key={"PK": f"CHANNEL#{channel_id}"}, + UpdateExpression=f"REMOVE connections[{index}]", + ) + + def _handle_message(self, event, connection_id): + """Handle incoming message — relay to other connections.""" + # Look up channel for this connection + result = self._table.get_item(Key={"PK": f"CONN#{connection_id}"}) + if "Item" not in result: + return {"statusCode": 404, "body": "Connection not found"} + + channel_id = result["Item"]["channelId"] + + # Atomic message count increment with limit check + try: + self._table.update_item( + Key={"PK": f"CHANNEL#{channel_id}"}, + UpdateExpression="SET messageCount = messageCount + :one", + ConditionExpression=("attribute_exists(PK) AND messageCount < :max"), + ExpressionAttributeValues={ + ":one": 1, + ":max": MAX_MESSAGES_PER_CHANNEL, + }, + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ConditionalCheckFailedException": + # Distinguish 404 vs 429 + channel_result = self._table.get_item(Key={"PK": f"CHANNEL#{channel_id}"}) + if "Item" not in channel_result: + return {"statusCode": 404, "body": "Channel not found"} + else: + return {"statusCode": 429, "body": "Message limit reached"} + raise + + # Get connections from channel metadata + channel_result = self._table.get_item(Key={"PK": f"CHANNEL#{channel_id}"}) + if "Item" not in channel_result: + return {"statusCode": 404, "body": "Channel not found"} + + connections = channel_result["Item"]["connections"] + message_body = event.get("body", "") + + self._relay_message(event, connection_id, connections, message_body) + + return {"statusCode": 200} + + def _relay_message(self, event, sender_connection_id, connections, message_body): + """Relay message to all connections except sender.""" + data = json.dumps( + { + "sender": sender_connection_id, + "body": message_body, + } + ) + for conn_id in connections: + if conn_id != sender_connection_id: + self._post_to_connection(event, conn_id, data) + + def _post_to_connection(self, event, connection_id, data): + """Post data to a WebSocket connection via API Gateway Management API.""" + client = self._get_apigw_client(event) + try: + client.post_to_connection( + ConnectionId=connection_id, + Data=data.encode("utf-8") if isinstance(data, str) else data, + ) + except client.exceptions.GoneException: + self._handle_disconnect(connection_id) + + def _get_apigw_client(self, event): + """Lazy API Gateway Management API client, cached by endpoint.""" + domain = event["requestContext"]["domainName"] + stage = event["requestContext"]["stage"] + endpoint = f"https://{domain}/{stage}" + + if endpoint not in self._apigw_clients: + self._apigw_clients[endpoint] = self._session.client( + "apigatewaymanagementapi", + endpoint_url=endpoint, + ) + + return self._apigw_clients[endpoint] diff --git a/lambda/tests/conftest.py b/lambda/tests/conftest.py index 85a4d9b3..f64a2e31 100644 --- a/lambda/tests/conftest.py +++ b/lambda/tests/conftest.py @@ -77,6 +77,7 @@ def setup_environment( monkeypatch.setenv("TOKEN_DURATION", "300") monkeypatch.setenv("AUTH_TABLE_NAME", "test-auth-table") monkeypatch.setenv("AUTH_SIGNING_KEY_ID", "test-signing-key-id") + monkeypatch.setenv("CHANNEL_TABLE_NAME", "test-channel-table") @pytest.fixture diff --git a/lambda/tests/entrypoint/test_channel_api.py b/lambda/tests/entrypoint/test_channel_api.py new file mode 100644 index 00000000..c9724725 --- /dev/null +++ b/lambda/tests/entrypoint/test_channel_api.py @@ -0,0 +1,48 @@ +"""Tests for Channel API lambda entrypoint""" + +from unittest.mock import MagicMock + +from src.entrypoint import channel_api_handler +from src.services.channel_service import ChannelService + + +class TestChannelApiHandler: + def test_delegates_to_channel_service(self, sample_lambda_context): + """Handler delegates to channel_service.handle.""" + mock_channel = MagicMock(spec=ChannelService) + mock_channel.handle.return_value = {"statusCode": 200} + + mock_sp = MagicMock() + mock_sp.channel_service = mock_channel + + event = { + "requestContext": { + "routeKey": "$connect", + "connectionId": "conn-1", + "domainName": "ws.example.com", + "stage": "prod", + }, + } + + result = channel_api_handler(event, sample_lambda_context, mock_sp) + + assert result == {"statusCode": 200} + mock_channel.handle.assert_called_once_with(event, sample_lambda_context) + + +class TestServiceProviderChannelProperties: + """Tests for ServiceProvider channel property initialization""" + + def test_channel_table_name_from_env(self, mock_service_provider): + """Test channel_table_name reads from environment.""" + assert mock_service_provider.channel_table_name == "test-channel-table" + + def test_channel_table_creates_table_resource(self, mock_service_provider): + """Test channel_table returns a DynamoDB Table resource.""" + table = mock_service_provider.channel_table + assert table is not None + + def test_channel_service_creates_instance(self, mock_service_provider): + """Test channel_service property creates ChannelService.""" + service = mock_service_provider.channel_service + assert isinstance(service, ChannelService) diff --git a/lambda/tests/fixtures/boto.py b/lambda/tests/fixtures/boto.py index 23cc862d..b24248b7 100644 --- a/lambda/tests/fixtures/boto.py +++ b/lambda/tests/fixtures/boto.py @@ -97,6 +97,24 @@ def kms_stubber(kms_client): stubber.deactivate() +@pytest.fixture +def apigw_client(boto_session): + """API Gateway Management API client for WebSocket connection posting.""" + return boto_session.client( + "apigatewaymanagementapi", + endpoint_url="https://test.execute-api.us-east-1.amazonaws.com/prod", + ) + + +@pytest.fixture +def apigw_stubber(apigw_client): + """Botocore Stubber for API Gateway Management API.""" + stubber = Stubber(apigw_client) + stubber.activate() + yield stubber + stubber.deactivate() + + @pytest.fixture(autouse=True) def boto_session(aws_region_name, aws_access_key_id, aws_secret_access_key, aws_session_token): return boto3.session.Session( @@ -120,13 +138,15 @@ def boto_session_patch(boto_session): @pytest.fixture(autouse=True) def boto_resource_patch( - boto_session, boto_session_patch, dynamodb_client, dynamodb_resource, kms_client + boto_session, boto_session_patch, dynamodb_client, dynamodb_resource, kms_client, apigw_client ) -> Generator: def client(service, *args, **kwargs): if service == "dynamodb": return dynamodb_client if service == "kms": return kms_client + if service == "apigatewaymanagementapi": + return apigw_client raise ValueError(f"client for {service} not recognized") diff --git a/lambda/tests/routes/auth/test_oauth_authorization.py b/lambda/tests/routes/auth/test_oauth_authorization.py index 408cf7a8..ff67bf58 100644 --- a/lambda/tests/routes/auth/test_oauth_authorization.py +++ b/lambda/tests/routes/auth/test_oauth_authorization.py @@ -138,6 +138,63 @@ def test_passes_keys_jwe_to_code_manager(self, route, mock_oauth_code_manager): ) +class TestOAuthAuthorizationRedirectUri: + def test_pairing_redirect_uri_accepted(self, route, mock_oauth_code_manager): + """Pairing redirect_uri is accepted and returned in response.""" + mock_oauth_code_manager.create_authorization_code.return_value = "code-pair" + + event = _make_event( + body=json.dumps( + { + "client_id": "client1", + "scope": "openid", + "state": "st", + "redirect_uri": "urn:ietf:wg:oauth:2.0:oob:pair-auth-webchannel", + } + ), + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert body["redirect"] == "urn:ietf:wg:oauth:2.0:oob:pair-auth-webchannel" + + def test_invalid_redirect_uri_returns_400(self, route, mock_oauth_code_manager): + """Invalid redirect_uri returns 400.""" + event = _make_event( + body=json.dumps( + { + "client_id": "client1", + "scope": "openid", + "state": "st", + "redirect_uri": "https://evil.example.com/callback", + } + ), + ) + response = route.handle(event) + assert response.status_code == 400 + body = json.loads(response.body) + assert body["errno"] == 107 + assert "redirect_uri" in body["message"] + + def test_default_redirect_uri_when_not_provided(self, route, mock_oauth_code_manager): + """Default redirect_uri used when none provided.""" + mock_oauth_code_manager.create_authorization_code.return_value = "code-default" + + event = _make_event( + body=json.dumps( + { + "client_id": "client1", + "scope": "openid", + "state": "st", + } + ), + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert body["redirect"] == "urn:ietf:wg:oauth:2.0:oob" + + class TestOAuthAuthorizationBind: def test_bind_registers_post_route(self, route): mock_api = MagicMock() diff --git a/lambda/tests/services/test_channel_service.py b/lambda/tests/services/test_channel_service.py new file mode 100644 index 00000000..175003fd --- /dev/null +++ b/lambda/tests/services/test_channel_service.py @@ -0,0 +1,932 @@ +"""Unit tests for ChannelService with DynamoDB stubber""" + +import json +from unittest.mock import patch + +import pytest +from botocore.exceptions import ClientError + +from src.services.channel_service import ( + CHANNEL_TTL_SECONDS, + MAX_CONNECTIONS_PER_CHANNEL, + MAX_MESSAGES_PER_CHANNEL, + ChannelService, +) + +CHANNEL_TABLE_NAME = "test-channel-table" +FIXED_UUID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" +FIXED_TIME = 1700000000 + + +def _ws_event(route_key="$default", connection_id="conn-1", body=None, query_params=None): + """Build a WebSocket API Gateway event dict.""" + event = { + "requestContext": { + "routeKey": route_key, + "connectionId": connection_id, + "domainName": "ws.example.com", + "stage": "prod", + }, + "body": body, + "queryStringParameters": query_params, + } + return event + + +class TestChannelService: + """Test ChannelService DynamoDB operations""" + + @pytest.fixture + def channel_table(self, boto_session, dynamodb_stubber): + resource = boto_session.resource("dynamodb") + table = resource.Table(CHANNEL_TABLE_NAME) + table.meta.client = dynamodb_stubber.client + return table + + @pytest.fixture + def service(self, channel_table, boto_session, apigw_client, apigw_stubber): + svc = ChannelService(table=channel_table, session=boto_session) + # Pre-populate the APIGW client cache with the shared stubbed client. + # The key must match what _get_apigw_client computes from _ws_event(): + # f"https://{domainName}/{stage}" => "https://ws.example.com/prod" + svc._apigw_clients["https://ws.example.com/prod"] = apigw_client + return svc + + # -- Constants ------------------------------------------------------------ + + def test_constants(self): + assert MAX_CONNECTIONS_PER_CHANNEL == 3 + assert MAX_MESSAGES_PER_CHANNEL == 10 + assert CHANNEL_TTL_SECONDS == 300 + + # -- Create channel ------------------------------------------------------- + + @patch("src.services.channel_service.uuid.uuid4", return_value=FIXED_UUID) + @patch("src.services.channel_service.time.time", return_value=FIXED_TIME) + def test_create_channel( + self, + mock_time, + mock_uuid, + service, + dynamodb_stubber, + apigw_stubber, + ): + """Create channel stores metadata + reverse lookup + sends channelId.""" + expiry = FIXED_TIME + CHANNEL_TTL_SECONDS + + # put_item for CHANNEL# metadata + dynamodb_stubber.add_response( + "put_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Item": { + "PK": f"CHANNEL#{FIXED_UUID}", + "connections": ["conn-1"], + "messageCount": 0, + "expiry": expiry, + }, + }, + ) + + # put_item for CONN# reverse lookup + dynamodb_stubber.add_response( + "put_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Item": { + "PK": "CONN#conn-1", + "channelId": FIXED_UUID, + "expiry": expiry, + }, + }, + ) + + # post_to_connection to notify creator of channelId + apigw_stubber.add_response( + "post_to_connection", + {}, + { + "ConnectionId": "conn-1", + "Data": json.dumps({"channelId": FIXED_UUID}).encode("utf-8"), + }, + ) + + event = _ws_event(route_key="$connect", connection_id="conn-1") + result = service.handle(event, None) + + assert result == {"statusCode": 200} + apigw_stubber.assert_no_pending_responses() + + # -- Join channel --------------------------------------------------------- + + @patch("src.services.channel_service.time.time", return_value=FIXED_TIME) + def test_join_channel( + self, + mock_time, + service, + dynamodb_stubber, + ): + """Join existing channel via atomic update_item + reverse lookup.""" + expiry = FIXED_TIME + CHANNEL_TTL_SECONDS + channel_id = "existing-channel" + + # update_item for atomic join + dynamodb_stubber.add_response("update_item", {}, None) + + # put_item for CONN# reverse lookup + dynamodb_stubber.add_response( + "put_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Item": { + "PK": "CONN#conn-2", + "channelId": channel_id, + "expiry": expiry, + }, + }, + ) + + event = _ws_event( + route_key="$connect", + connection_id="conn-2", + query_params={"channelId": channel_id}, + ) + result = service.handle(event, None) + + assert result == {"statusCode": 200} + + # -- Join nonexistent channel --------------------------------------------- + + def test_join_nonexistent_channel_returns_404( + self, + service, + dynamodb_stubber, + ): + """ConditionalCheckFailed + empty get_item => 404.""" + channel_id = "no-such-channel" + + # update_item fails with ConditionalCheckFailedException + dynamodb_stubber.add_client_error( + "update_item", + service_error_code="ConditionalCheckFailedException", + service_message="Condition not met", + ) + + # get_item to distinguish 404 vs 403 => empty + dynamodb_stubber.add_response( + "get_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + event = _ws_event( + route_key="$connect", + connection_id="conn-2", + query_params={"channelId": channel_id}, + ) + result = service.handle(event, None) + + assert result == {"statusCode": 404, "body": "Channel not found"} + + # -- Join full channel ---------------------------------------------------- + + def test_join_full_channel_returns_403( + self, + service, + dynamodb_stubber, + ): + """ConditionalCheckFailed + channel exists => 403.""" + channel_id = "full-channel" + + # update_item fails with ConditionalCheckFailedException + dynamodb_stubber.add_client_error( + "update_item", + service_error_code="ConditionalCheckFailedException", + service_message="Condition not met", + ) + + # get_item returns existing channel (full) + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"CHANNEL#{channel_id}"}, + "connections": {"L": [{"S": "c1"}, {"S": "c2"}, {"S": "c3"}]}, + "messageCount": {"N": "0"}, + "expiry": {"N": str(FIXED_TIME + CHANNEL_TTL_SECONDS)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + event = _ws_event( + route_key="$connect", + connection_id="conn-4", + query_params={"channelId": channel_id}, + ) + result = service.handle(event, None) + + assert result == {"statusCode": 403, "body": "Channel full"} + + # -- Disconnect with cleanup ---------------------------------------------- + + def test_disconnect_cleans_up( + self, + service, + dynamodb_stubber, + ): + """Disconnect removes reverse lookup then patches connections list.""" + channel_id = "chan-1" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # delete_item for CONN# + dynamodb_stubber.add_response( + "delete_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # get_item for CHANNEL# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"CHANNEL#{channel_id}"}, + "connections": {"L": [{"S": "conn-1"}, {"S": "conn-2"}]}, + "messageCount": {"N": "0"}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + # update_item to REMOVE connections[0] + dynamodb_stubber.add_response("update_item", {}, None) + + event = _ws_event(route_key="$disconnect", connection_id="conn-1") + result = service.handle(event, None) + + assert result == {"statusCode": 200} + + # -- Disconnect unknown connection ---------------------------------------- + + def test_disconnect_unknown_connection( + self, + service, + dynamodb_stubber, + ): + """Disconnect with no reverse lookup => no-op.""" + # get_item for CONN# => empty + dynamodb_stubber.add_response( + "get_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#unknown"}, + }, + ) + + event = _ws_event(route_key="$disconnect", connection_id="unknown") + result = service.handle(event, None) + + assert result == {"statusCode": 200} + + # -- Relay message -------------------------------------------------------- + + def test_relay_message( + self, + service, + dynamodb_stubber, + apigw_stubber, + ): + """Message relayed to other connections in the channel.""" + channel_id = "chan-1" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # update_item for atomic message count increment + dynamodb_stubber.add_response("update_item", {}, None) + + # get_item for CHANNEL# (connections) + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"CHANNEL#{channel_id}"}, + "connections": {"L": [{"S": "conn-1"}, {"S": "conn-2"}]}, + "messageCount": {"N": "1"}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + # post_to_connection to relay message to conn-2 + apigw_stubber.add_response( + "post_to_connection", + {}, + { + "ConnectionId": "conn-2", + "Data": json.dumps({"sender": "conn-1", "body": "hello"}).encode("utf-8"), + }, + ) + + event = _ws_event( + route_key="$default", + connection_id="conn-1", + body="hello", + ) + result = service.handle(event, None) + + assert result == {"statusCode": 200} + apigw_stubber.assert_no_pending_responses() + + # -- Unknown connection message ------------------------------------------- + + def test_unknown_connection_message_returns_404( + self, + service, + dynamodb_stubber, + ): + """Message from unknown connection => 404.""" + # get_item for CONN# => empty + dynamodb_stubber.add_response( + "get_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-x"}, + }, + ) + + event = _ws_event(route_key="$default", connection_id="conn-x", body="hi") + result = service.handle(event, None) + + assert result == {"statusCode": 404, "body": "Connection not found"} + + # -- Channel not found on message ----------------------------------------- + + def test_channel_not_found_on_message_returns_404( + self, + service, + dynamodb_stubber, + ): + """Message with valid connection but missing channel => 404.""" + channel_id = "gone-channel" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # update_item fails (channel deleted) + dynamodb_stubber.add_client_error( + "update_item", + service_error_code="ConditionalCheckFailedException", + service_message="Condition not met", + ) + + # get_item for CHANNEL# => empty (channel gone) + dynamodb_stubber.add_response( + "get_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + event = _ws_event(route_key="$default", connection_id="conn-1", body="hi") + result = service.handle(event, None) + + assert result == {"statusCode": 404, "body": "Channel not found"} + + # -- Message limit -------------------------------------------------------- + + def test_message_limit_returns_429( + self, + service, + dynamodb_stubber, + ): + """Message count at limit => 429.""" + channel_id = "busy-channel" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # update_item fails (message count at limit) + dynamodb_stubber.add_client_error( + "update_item", + service_error_code="ConditionalCheckFailedException", + service_message="Condition not met", + ) + + # get_item for CHANNEL# => exists (so it's a 429, not 404) + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"CHANNEL#{channel_id}"}, + "connections": {"L": [{"S": "conn-1"}]}, + "messageCount": {"N": "10"}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + event = _ws_event(route_key="$default", connection_id="conn-1", body="hi") + result = service.handle(event, None) + + assert result == {"statusCode": 429, "body": "Message limit reached"} + + # -- GoneException triggers cleanup --------------------------------------- + + def test_gone_exception_triggers_cleanup( + self, + service, + dynamodb_stubber, + apigw_stubber, + ): + """When post_to_connection raises GoneException, stale conn is cleaned up.""" + channel_id = "chan-1" + + # get_item for CONN# (sender) + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # update_item for message count + dynamodb_stubber.add_response("update_item", {}, None) + + # get_item for CHANNEL# (connections) + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"CHANNEL#{channel_id}"}, + "connections": {"L": [{"S": "conn-1"}, {"S": "conn-stale"}]}, + "messageCount": {"N": "1"}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + # post_to_connection raises GoneException for stale conn + apigw_stubber.add_client_error( + "post_to_connection", + service_error_code="GoneException", + service_message="Connection gone", + expected_params={ + "ConnectionId": "conn-stale", + "Data": json.dumps({"sender": "conn-1", "body": "ping"}).encode("utf-8"), + }, + ) + + # _handle_disconnect cleanup stubs for the stale connection: + # get_item for CONN#conn-stale + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-stale"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-stale"}, + }, + ) + + # delete_item for CONN#conn-stale + dynamodb_stubber.add_response( + "delete_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-stale"}, + }, + ) + + # get_item for CHANNEL# to find index + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"CHANNEL#{channel_id}"}, + "connections": {"L": [{"S": "conn-1"}, {"S": "conn-stale"}]}, + "messageCount": {"N": "1"}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + # update_item to REMOVE connections[1] + dynamodb_stubber.add_response("update_item", {}, None) + + event = _ws_event(route_key="$default", connection_id="conn-1", body="ping") + result = service.handle(event, None) + + assert result == {"statusCode": 200} + apigw_stubber.assert_no_pending_responses() + + # -- Empty body relay ----------------------------------------------------- + + def test_empty_body_relay( + self, + service, + dynamodb_stubber, + apigw_stubber, + ): + """Relay works when body is missing from event.""" + channel_id = "chan-1" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # update_item for message count + dynamodb_stubber.add_response("update_item", {}, None) + + # get_item for CHANNEL# (connections) + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"CHANNEL#{channel_id}"}, + "connections": {"L": [{"S": "conn-1"}, {"S": "conn-2"}]}, + "messageCount": {"N": "1"}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + # post_to_connection to relay message to conn-2 (empty body) + apigw_stubber.add_response( + "post_to_connection", + {}, + { + "ConnectionId": "conn-2", + "Data": json.dumps({"sender": "conn-1", "body": ""}).encode("utf-8"), + }, + ) + + # Event with no body key + event = _ws_event(route_key="$default", connection_id="conn-1") + del event["body"] + result = service.handle(event, None) + + assert result == {"statusCode": 200} + apigw_stubber.assert_no_pending_responses() + + # -- Lazy APIGW client init ----------------------------------------------- + + def test_lazy_apigw_client_init( + self, + channel_table, + boto_session, + ): + """Client is created lazily on first use and cached by endpoint.""" + svc = ChannelService(table=channel_table, session=boto_session) + assert svc._apigw_clients == {} + + event = _ws_event() + client1 = svc._get_apigw_client(event) + assert "https://ws.example.com/prod" in svc._apigw_clients + assert client1 is svc._apigw_clients["https://ws.example.com/prod"] + + # Second call returns the same cached client + client2 = svc._get_apigw_client(event) + assert client1 is client2 + + # -- Disconnect: channel gone after CONN delete ---------------------------- + + def test_disconnect_channel_gone_after_conn_delete( + self, + service, + dynamodb_stubber, + ): + """Disconnect when channel disappears between CONN delete and channel lookup.""" + channel_id = "vanished-chan" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # delete_item for CONN# + dynamodb_stubber.add_response( + "delete_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # get_item for CHANNEL# => empty (channel gone) + dynamodb_stubber.add_response( + "get_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + event = _ws_event(route_key="$disconnect", connection_id="conn-1") + result = service.handle(event, None) + + assert result == {"statusCode": 200} + + # -- Disconnect: connection not in connections list ----------------------- + + def test_disconnect_connection_not_in_list( + self, + service, + dynamodb_stubber, + ): + """Disconnect when connection is not in the channel's connections list.""" + channel_id = "chan-1" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # delete_item for CONN# + dynamodb_stubber.add_response( + "delete_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # get_item for CHANNEL# => connection already removed from list + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"CHANNEL#{channel_id}"}, + "connections": {"L": [{"S": "conn-other"}]}, + "messageCount": {"N": "0"}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + event = _ws_event(route_key="$disconnect", connection_id="conn-1") + result = service.handle(event, None) + + assert result == {"statusCode": 200} + + # -- Join: unexpected ClientError re-raised ------------------------------ + + def test_join_unexpected_client_error_reraised( + self, + service, + dynamodb_stubber, + ): + """Non-ConditionalCheckFailed ClientError is re-raised on join.""" + dynamodb_stubber.add_client_error( + "update_item", + service_error_code="InternalServerError", + service_message="Unexpected error", + ) + + event = _ws_event( + route_key="$connect", + connection_id="conn-2", + query_params={"channelId": "some-channel"}, + ) + with pytest.raises(ClientError): + service.handle(event, None) + + # -- Message: unexpected ClientError re-raised --------------------------- + + def test_message_unexpected_client_error_reraised( + self, + service, + dynamodb_stubber, + ): + """Non-ConditionalCheckFailed ClientError is re-raised on message.""" + channel_id = "chan-1" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # update_item fails with unexpected error + dynamodb_stubber.add_client_error( + "update_item", + service_error_code="InternalServerError", + service_message="Unexpected error", + ) + + event = _ws_event(route_key="$default", connection_id="conn-1", body="hi") + with pytest.raises(ClientError): + service.handle(event, None) + + # -- Message: channel gone between count update and connections lookup ---- + + def test_message_channel_gone_after_count_update( + self, + service, + dynamodb_stubber, + ): + """Channel disappears between message count update and connections fetch.""" + channel_id = "ephemeral-chan" + + # get_item for CONN# + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": "CONN#conn-1"}, + "channelId": {"S": channel_id}, + "expiry": {"N": str(FIXED_TIME)}, + } + }, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": "CONN#conn-1"}, + }, + ) + + # update_item for message count succeeds + dynamodb_stubber.add_response("update_item", {}, None) + + # get_item for CHANNEL# => empty (TTL expired between calls) + dynamodb_stubber.add_response( + "get_item", + {}, + { + "TableName": CHANNEL_TABLE_NAME, + "Key": {"PK": f"CHANNEL#{channel_id}"}, + }, + ) + + event = _ws_event(route_key="$default", connection_id="conn-1", body="hi") + result = service.handle(event, None) + + assert result == {"statusCode": 404, "body": "Channel not found"} + + # -- Unknown route -------------------------------------------------------- + + def test_unknown_route_returns_400(self, service): + """Unknown route key returns 400.""" + event = _ws_event(route_key="$unknown") + result = service.handle(event, None) + + assert result == {"statusCode": 400, "body": "Unknown route"} diff --git a/lib/app.ts b/lib/app.ts index 36f1a659..f81a7725 100644 --- a/lib/app.ts +++ b/lib/app.ts @@ -29,6 +29,7 @@ new GitHubOidcStack(app, "GitHubOidcStack", { authApiDomain: serviceStack.authApiDomain, tokenApiDomain: serviceStack.tokenApiDomain, profileApiDomain: serviceStack.profileApiDomain, + channelApiDomain: serviceStack.channelApiDomain, oidcProviderUrl: serviceStack.oidcProviderUrlParam, clientId: serviceStack.clientIdParam, }); diff --git a/lib/config/service.ts b/lib/config/service.ts index 80131d25..b7ba9f92 100644 --- a/lib/config/service.ts +++ b/lib/config/service.ts @@ -3,4 +3,5 @@ export enum Service { TOKEN = "token", PROFILE = "profile", STORAGE = "storage", + CHANNEL = "channel", } diff --git a/lib/stacks/frontend.ts b/lib/stacks/frontend.ts index 0c11e6ff..c14fd448 100644 --- a/lib/stacks/frontend.ts +++ b/lib/stacks/frontend.ts @@ -26,6 +26,7 @@ export interface FrontendStackProps extends StackProps { authApiDomain: string; tokenApiDomain: string; profileApiDomain: string; + channelApiDomain: string; oidcProviderUrl: IStringParameter; clientId: IStringParameter; } @@ -86,6 +87,7 @@ export class FrontendStack extends Stack { profile_server_base_url: `https://${this.props.profileApiDomain}`, sync_tokenserver_base_url: `https://${this.props.tokenApiDomain}`, content_url: `https://${this.domainName}`, + pairing_server_base_uri: `wss://${this.props.channelApiDomain}`, }); return new CfFunction(this, "WellKnownFunction", { @@ -158,6 +160,7 @@ export class FrontendStack extends Stack { redirectUri: `https://${this.domainName}`, authServerUrl: `https://${this.props.authApiDomain}`, scopes: ["openid", "profile", "email"], + pairingServerUrl: `wss://${this.props.channelApiDomain}`, }), ], destinationBucket: this.bucket, diff --git a/lib/stacks/service.ts b/lib/stacks/service.ts index 0246653d..8b2ae34d 100644 --- a/lib/stacks/service.ts +++ b/lib/stacks/service.ts @@ -13,6 +13,13 @@ import { SecurityPolicy, SpecRestApi, } from "aws-cdk-lib/aws-apigateway"; +import { + DomainName as ApiGwV2DomainName, + ApiMapping, + WebSocketApi, + WebSocketStage, +} from "aws-cdk-lib/aws-apigatewayv2"; +import {WebSocketLambdaIntegration} from "aws-cdk-lib/aws-apigatewayv2-integrations"; import {Certificate, CertificateValidation} from "aws-cdk-lib/aws-certificatemanager"; import { AttributeType, @@ -32,7 +39,7 @@ import { RecordTarget, RecordType, } from "aws-cdk-lib/aws-route53"; -import {ApiGateway} from "aws-cdk-lib/aws-route53-targets"; +import {ApiGateway, ApiGatewayv2DomainProperties} from "aws-cdk-lib/aws-route53-targets"; import {IStringParameter, StringParameter} from "aws-cdk-lib/aws-ssm"; import {BASE_DOMAIN, HOSTED_ZONE_ID, StageType} from "../config"; @@ -68,6 +75,10 @@ export class ServiceStack extends Stack { return `${Service.PROFILE}.${this.props.stageType}.${BASE_DOMAIN}`; } + public get channelApiDomain(): string { + return `${Service.CHANNEL}.${this.props.stageType}.${BASE_DOMAIN}`; + } + // Auth Service public readonly tokenUsersTable: Table; public readonly tokenCacheTable: Table; @@ -89,6 +100,10 @@ export class ServiceStack extends Stack { public readonly storageHandler: IFunction; public readonly storageApi: SpecRestApi; + // Channel Service + public readonly channelTable: Table; + public readonly channelHandler: PythonFunction; + constructor(scope: Construct, id: string, props: ServiceStackProps) { super(scope, id, props); @@ -127,6 +142,11 @@ export class ServiceStack extends Stack { this.tokenApi = this.buildApi(Service.TOKEN, this.tokenHandler); this.profileApi = this.buildApi(Service.PROFILE, this.profileHandler); this.storageApi = this.buildApi(Service.STORAGE, this.storageHandler); + + // Channel Service + this.channelTable = this.buildChannelTable(); + this.channelHandler = this.buildChannelApiHandler(); + this.buildChannelWebSocketApi(); } private buildStorageTable(): Table { @@ -353,6 +373,92 @@ export class ServiceStack extends Stack { return fn; } + private buildChannelTable(): Table { + return new Table(this, "ChannelTable", { + tableName: `ffsync-channel-${this.props.stageType.toLowerCase()}`, + partitionKey: {name: "PK", type: AttributeType.STRING}, + billingMode: BillingMode.PAY_PER_REQUEST, + encryption: TableEncryption.AWS_MANAGED, + timeToLiveAttribute: "expiry", + removalPolicy: RemovalPolicy.DESTROY, + }); + } + + private buildChannelApiHandler(): PythonFunction { + const fn = new PythonFunction(this, "ChannelApiHandler", { + rootDir: path.join(__dirname, "../../lambda"), + index: "src/entrypoint/__init__.py", + runtime: Runtime.PYTHON_3_14, + architecture: Architecture.ARM_64, + handler: "channel_api_handler", + functionName: `ffsync-channel-api-${this.props.stageType.toLowerCase()}`, + timeout: Duration.seconds(10), + memorySize: 256, + environment: { + STAGE: this.props.stageType.toLowerCase(), + CHANNEL_TABLE_NAME: this.channelTable.tableName, + }, + bundling: { + assetExcludes: [".venv/", ".git/", "tests/", "htmlcov/", ".pytest_cache/", ".mypy_cache/"], + }, + }); + + this.channelTable.grantReadWriteData(fn); + + return fn; + } + + private buildChannelWebSocketApi(): void { + const stage = this.props.stageType.toLowerCase(); + const integration = new WebSocketLambdaIntegration("ChannelIntegration", this.channelHandler); + + const wsApi = new WebSocketApi(this, "ChannelWebSocketApi", { + apiName: `ffsync-channel-${stage}`, + connectRouteOptions: {integration}, + disconnectRouteOptions: {integration}, + defaultRouteOptions: {integration}, + }); + + const wsStage = new WebSocketStage(this, "ChannelWebSocketStage", { + webSocketApi: wsApi, + stageName: stage, + autoDeploy: true, + }); + + const domainName = this.channelApiDomain; + const certificate = new Certificate(this, "ChannelCertificate", { + domainName, + validation: CertificateValidation.fromDns(this.hostedZone), + }); + + const apiDomainName = new ApiGwV2DomainName(this, "ChannelDomainName", { + domainName, + certificate, + }); + + new ApiMapping(this, "ChannelApiMapping", { + api: wsApi, + domainName: apiDomainName, + stage: wsStage, + }); + + [RecordType.A, RecordType.AAAA].map((recordType) => { + new RecordSet(this, `Channel${recordType}RecordSet`, { + recordType, + zone: this.hostedZone, + recordName: domainName, + target: RecordTarget.fromAlias( + new ApiGatewayv2DomainProperties( + apiDomainName.regionalDomainName, + apiDomainName.regionalHostedZoneId, + ), + ), + }); + }); + + wsApi.grantManageConnections(this.channelHandler); + } + private buildApiExecuteRole(): Role { return new Role(this, "ApiRole", { roleName: `ffsync-api-role-${this.props.stageType.toLowerCase()}`, diff --git a/test/frontend.test.ts b/test/frontend.test.ts index d7d086d8..4ab1c39d 100644 --- a/test/frontend.test.ts +++ b/test/frontend.test.ts @@ -32,6 +32,7 @@ describe("FrontendStack", () => { authApiDomain: "api.example.com", tokenApiDomain: "token.example.com", profileApiDomain: "profile.example.com", + channelApiDomain: "channel.example.com", oidcProviderUrl: StringParameter.fromStringParameterName(helperStack, "OidcParam", "/test/oidc-url"), clientId: StringParameter.fromStringParameterName(helperStack, "ClientIdParam", "/test/client-id"), });