diff --git a/lambda/src/environment/service_provider.py b/lambda/src/environment/service_provider.py index bc3dd470..36c3249a 100644 --- a/lambda/src/environment/service_provider.py +++ b/lambda/src/environment/service_provider.py @@ -6,7 +6,11 @@ import boto3 from aws_lambda_powertools.event_handler import CORSConfig, Response +from src.routes.auth.account_attached_clients import AccountAttachedClientsRoute from src.routes.auth.account_create import AccountCreateRoute +from src.routes.auth.account_device import AccountDeviceRoute +from src.routes.auth.account_devices import AccountDevicesRoute +from src.routes.auth.account_devices_notify import AccountDevicesNotifyRoute from src.routes.auth.account_keys import AccountKeysRoute from src.routes.auth.account_login import AccountLoginRoute from src.routes.auth.account_status import AccountStatusRoute @@ -46,6 +50,7 @@ ) from src.services.auth_account_manager import AuthAccountManager from src.services.channel_service import ChannelService +from src.services.device_manager import DeviceManager from src.services.fxa_token_manager import FxATokenManager from src.services.hawk_service import HawkService from src.services.jwt_service import JWTService @@ -268,6 +273,10 @@ def kms_client(self): # pragma: nocover def auth_account_manager(self) -> AuthAccountManager: return AuthAccountManager(table=self.auth_table) + @cached_property + def device_manager(self) -> DeviceManager: + return DeviceManager(table=self.auth_table) + @cached_property def fxa_token_manager(self) -> FxATokenManager: return FxATokenManager(table=self.auth_table) @@ -351,6 +360,22 @@ def auth_api_router(self): oidc_validator=self.oidc_validator, account_manager=self.auth_account_manager, ), + # Device management routes + AccountDeviceRoute( + device_manager=self.device_manager, + middlewares=[self.session_hawk_middleware], + ), + AccountDevicesRoute( + device_manager=self.device_manager, + middlewares=[self.session_hawk_middleware], + ), + AccountAttachedClientsRoute( + device_manager=self.device_manager, + middlewares=[self.session_hawk_middleware], + ), + AccountDevicesNotifyRoute( + middlewares=[self.session_hawk_middleware], + ), ], middlewares=[WeaveTimestampMiddleware()], cors=self.cors_config, diff --git a/lambda/src/middlewares/hawk_auth.py b/lambda/src/middlewares/hawk_auth.py index 13d3cfc5..b077d93e 100644 --- a/lambda/src/middlewares/hawk_auth.py +++ b/lambda/src/middlewares/hawk_auth.py @@ -4,7 +4,7 @@ authentication. Pass hawk_service for storage API, token_manager for auth API session-authenticated routes. -On success, injects ``hawk_uid`` into event["requestContext"]. +On success, injects ``hawk_uid`` and ``hawk_token_id`` into event["requestContext"]. On failure, raises HawkAuthenticationError (handle with router exception handler). """ @@ -85,3 +85,7 @@ def _validate_session_hawk(self, event, auth_header, method, path, host, port): raise HawkAuthenticationError("Invalid or expired session token") event["requestContext"]["hawk_uid"] = uid + # Inject the Hawk id (session token ID) for device correlation + event["requestContext"]["hawk_token_id"] = ( + FxATokenManager.extract_token_id_from_hawk_header(auth_header) or "" + ) diff --git a/lambda/src/routes/auth/account_attached_clients.py b/lambda/src/routes/auth/account_attached_clients.py new file mode 100644 index 00000000..8aab5e7d --- /dev/null +++ b/lambda/src/routes/auth/account_attached_clients.py @@ -0,0 +1,58 @@ +"""AccountAttachedClients route — GET /v1/account/attached_clients""" + +import json +from typing import Sequence + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler + +from src.services.device_manager import DeviceManager +from src.shared.base_route import BaseRoute + + +class AccountAttachedClientsRoute(BaseRoute): + """Return attached clients derived from device records.""" + + def __init__( + self, + device_manager: DeviceManager, + middlewares: Sequence[BaseMiddlewareHandler] = (), + ): + self._device_manager = device_manager + self.middlewares = middlewares + + def bind(self, app: APIGatewayRestResolver): + @app.get("/v1/account/attached_clients", middlewares=list(self.middlewares)) + def handle_account_attached_clients(): + return self.handle(app.current_event) + + def handle(self, event) -> Response: + uid = event["requestContext"]["hawk_uid"] + session_token_id = event["requestContext"].get("hawk_token_id", "") + + devices = self._device_manager.get_devices(uid) + clients = [] + for d in devices: + clients.append( + { + "clientId": None, + "deviceId": d.get("id"), + "sessionTokenId": d.get("sessionTokenId"), + "refreshTokenId": None, + "isCurrentSession": d.get("sessionTokenId") == session_token_id, + "deviceType": d.get("type"), + "name": d.get("name"), + "createdTime": d.get("createdAt"), + "lastAccessTime": d.get("lastAccessTime"), + "scope": None, + "location": {}, + "userAgent": "", + "os": None, + } + ) + + return Response( + status_code=200, + content_type="application/json", + body=json.dumps(clients), + ) diff --git a/lambda/src/routes/auth/account_device.py b/lambda/src/routes/auth/account_device.py new file mode 100644 index 00000000..a2fd0dd0 --- /dev/null +++ b/lambda/src/routes/auth/account_device.py @@ -0,0 +1,39 @@ +"""AccountDevice route — POST /v1/account/device""" + +import json +from typing import Sequence + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler + +from src.services.device_manager import DeviceManager +from src.shared.base_route import BaseRoute + + +class AccountDeviceRoute(BaseRoute): + """Create or update a device registration.""" + + def __init__( + self, + device_manager: DeviceManager, + middlewares: Sequence[BaseMiddlewareHandler] = (), + ): + self._device_manager = device_manager + self.middlewares = middlewares + + def bind(self, app: APIGatewayRestResolver): + @app.post("/v1/account/device", middlewares=list(self.middlewares)) + def handle_account_device(): + return self.handle(app.current_event) + + def handle(self, event) -> Response: + uid = event["requestContext"]["hawk_uid"] + session_token_id = event["requestContext"].get("hawk_token_id", "") + body = json.loads(event.body or "{}") + device = self._device_manager.upsert_device(uid, session_token_id, body) + + return Response( + status_code=200, + content_type="application/json", + body=json.dumps(device), + ) diff --git a/lambda/src/routes/auth/account_devices.py b/lambda/src/routes/auth/account_devices.py new file mode 100644 index 00000000..3cc797f7 --- /dev/null +++ b/lambda/src/routes/auth/account_devices.py @@ -0,0 +1,45 @@ +"""AccountDevices route — GET /v1/account/devices""" + +import json +from typing import Sequence + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler + +from src.services.device_manager import DeviceManager +from src.shared.base_route import BaseRoute + + +class AccountDevicesRoute(BaseRoute): + """List all devices for the authenticated user.""" + + def __init__( + self, + device_manager: DeviceManager, + middlewares: Sequence[BaseMiddlewareHandler] = (), + ): + self._device_manager = device_manager + self.middlewares = middlewares + + def bind(self, app: APIGatewayRestResolver): + @app.get("/v1/account/devices", middlewares=list(self.middlewares)) + def handle_account_devices(): + return self.handle(app.current_event) + + def handle(self, event) -> Response: + uid = event["requestContext"]["hawk_uid"] + session_token_id = event["requestContext"].get("hawk_token_id", "") + + params = event.query_string_parameters or {} + filter_ts = params.get("filterIdleDevicesTimestamp") + filter_idle = int(filter_ts) if filter_ts else None + + devices = self._device_manager.get_devices(uid, filter_idle) + for d in devices: + d["isCurrentDevice"] = d.get("sessionTokenId") == session_token_id + + return Response( + status_code=200, + content_type="application/json", + body=json.dumps(devices), + ) diff --git a/lambda/src/routes/auth/account_devices_notify.py b/lambda/src/routes/auth/account_devices_notify.py new file mode 100644 index 00000000..7a37abd4 --- /dev/null +++ b/lambda/src/routes/auth/account_devices_notify.py @@ -0,0 +1,28 @@ +"""AccountDevicesNotify route — POST /v1/account/devices/notify (no-op)""" + +import json +from typing import Sequence + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler + +from src.shared.base_route import BaseRoute + + +class AccountDevicesNotifyRoute(BaseRoute): + """No-op push notification endpoint for device-to-device messaging.""" + + def __init__(self, middlewares: Sequence[BaseMiddlewareHandler] = ()): + self.middlewares = middlewares + + def bind(self, app: APIGatewayRestResolver): + @app.post("/v1/account/devices/notify", middlewares=list(self.middlewares)) + def handle_account_devices_notify(): + return self.handle(app.current_event) + + def handle(self, event) -> Response: + return Response( + status_code=200, + content_type="application/json", + body=json.dumps({}), + ) diff --git a/lambda/src/services/device_manager.py b/lambda/src/services/device_manager.py new file mode 100644 index 00000000..cb0369ab --- /dev/null +++ b/lambda/src/services/device_manager.py @@ -0,0 +1,68 @@ +"""Device management for FxA device registration and listing.""" + +import re +import time +import uuid +from typing import Optional + +from boto3.dynamodb.conditions import Attr + +DEVICE_PREFIX = "DEVICE" +_HAWK_ID_PATTERN = re.compile(r'id="([^"]+)"') + + +class DeviceManager: + """Manages FxA device records in DynamoDB.""" + + def __init__(self, table): + self.table = table + + def _device_pk(self, uid: str, device_id: str) -> str: + return f"{DEVICE_PREFIX}#{uid}#{device_id}" + + def upsert_device(self, uid: str, session_token_id: str, data: dict) -> dict: + """Create or update a device record.""" + now = int(time.time() * 1000) + device_id = data.get("id") + + if device_id: + # Update: get existing, merge new fields + response = self.table.get_item(Key={"PK": self._device_pk(uid, device_id)}) + existing = response.get("Item", {}) + existing.pop("PK", None) + # Merge: new data overwrites existing, but preserve createdAt + device = {**existing, **{k: v for k, v in data.items() if v is not None}} + device["lastAccessTime"] = now + device["sessionTokenId"] = session_token_id + else: + # Create new device + device_id = uuid.uuid4().hex + device = { + "id": device_id, + "name": data.get("name", ""), + "type": data.get("type", "desktop"), + "pushCallback": data.get("pushCallback"), + "pushPublicKey": data.get("pushPublicKey"), + "pushAuthKey": data.get("pushAuthKey"), + "pushEndpointExpired": False, + "availableCommands": data.get("availableCommands", {}), + "sessionTokenId": session_token_id, + "createdAt": now, + "lastAccessTime": now, + } + + self.table.put_item(Item={"PK": self._device_pk(uid, device_id), **device}) + return device + + def get_devices(self, uid: str, filter_idle_timestamp: Optional[int] = None) -> list[dict]: + """List all devices for a user.""" + response = self.table.scan( + FilterExpression=Attr("PK").begins_with(f"{DEVICE_PREFIX}#{uid}#") + ) + devices = [] + for item in response.get("Items", []): + item.pop("PK", None) + if filter_idle_timestamp and item.get("lastAccessTime", 0) < filter_idle_timestamp: + continue + devices.append(item) + return devices diff --git a/lambda/src/services/fxa_token_manager.py b/lambda/src/services/fxa_token_manager.py index 04d254bd..dea2cf27 100644 --- a/lambda/src/services/fxa_token_manager.py +++ b/lambda/src/services/fxa_token_manager.py @@ -1,5 +1,6 @@ """FxA Token Manager for session tokens and key-fetch tokens in DynamoDB""" +import re import time from typing import Optional @@ -19,10 +20,20 @@ SESSION_PREFIX = "SESSION" KEYFETCH_PREFIX = "KEYFETCH" +_HAWK_ID_PATTERN = re.compile(r'id="([^"]+)"') + class FxATokenManager: """Manages FxA session tokens and key-fetch tokens in DynamoDB""" + @staticmethod + def extract_token_id_from_hawk_header(authorization_header: str) -> str | None: + """Extract the Hawk id field (session token ID) from an Authorization header.""" + if not authorization_header: + return None + match = _HAWK_ID_PATTERN.search(authorization_header) + return match.group(1) if match else None + def __init__( self, table, diff --git a/lambda/tests/routes/auth/test_account_attached_clients.py b/lambda/tests/routes/auth/test_account_attached_clients.py new file mode 100644 index 00000000..3db9ea28 --- /dev/null +++ b/lambda/tests/routes/auth/test_account_attached_clients.py @@ -0,0 +1,119 @@ +"""Unit tests for AccountAttachedClients route""" + +import json +from unittest.mock import MagicMock + +import pytest +from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent + +from src.routes.auth.account_attached_clients import AccountAttachedClientsRoute + + +@pytest.fixture +def device_manager(): + return MagicMock() + + +@pytest.fixture +def route(device_manager): + return AccountAttachedClientsRoute(device_manager=device_manager, middlewares=[]) + + +class TestAccountAttachedClients: + def test_returns_attached_clients(self, route, device_manager): + device_manager.get_devices.return_value = [ + { + "id": "dev1", + "name": "Desktop", + "type": "desktop", + "sessionTokenId": "token123", + "createdAt": 1000, + "lastAccessTime": 2000, + }, + ] + event = APIGatewayProxyEvent( + { + "httpMethod": "GET", + "path": "/v1/account/attached_clients", + "headers": {"authorization": 'Hawk id="token123"'}, + "queryStringParameters": {}, + "body": None, + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert len(body) == 1 + client = body[0] + assert client["clientId"] is None + assert client["deviceId"] == "dev1" + assert client["sessionTokenId"] == "token123" + assert client["refreshTokenId"] is None + assert client["deviceType"] == "desktop" + assert client["name"] == "Desktop" + assert client["createdTime"] == 1000 + assert client["lastAccessTime"] == 2000 + assert client["scope"] is None + assert client["location"] == {} + assert client["userAgent"] == "" + assert client["os"] is None + + def test_is_current_session_set(self, route, device_manager): + device_manager.get_devices.return_value = [ + { + "id": "dev1", + "sessionTokenId": "token123", + "type": "desktop", + "name": "A", + "createdAt": 1000, + "lastAccessTime": 2000, + }, + { + "id": "dev2", + "sessionTokenId": "other-token", + "type": "mobile", + "name": "B", + "createdAt": 1500, + "lastAccessTime": 2500, + }, + ] + event = APIGatewayProxyEvent( + { + "httpMethod": "GET", + "path": "/v1/account/attached_clients", + "headers": {"authorization": 'Hawk id="token123"'}, + "queryStringParameters": {}, + "body": None, + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + body = json.loads(response.body) + assert body[0]["isCurrentSession"] is True + assert body[1]["isCurrentSession"] is False + + def test_returns_empty_list(self, route, device_manager): + device_manager.get_devices.return_value = [] + event = APIGatewayProxyEvent( + { + "httpMethod": "GET", + "path": "/v1/account/attached_clients", + "headers": {"authorization": 'Hawk id="token123"'}, + "queryStringParameters": {}, + "body": None, + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert body == [] + + +class TestAccountAttachedClientsBind: + def test_bind_registers_get_route(self, route): + mock_api = MagicMock() + mock_api.get = MagicMock(return_value=lambda f: f) + route.bind(mock_api) + mock_api.get.assert_called_once_with("/v1/account/attached_clients", middlewares=[]) diff --git a/lambda/tests/routes/auth/test_account_device.py b/lambda/tests/routes/auth/test_account_device.py new file mode 100644 index 00000000..f727da5c --- /dev/null +++ b/lambda/tests/routes/auth/test_account_device.py @@ -0,0 +1,113 @@ +"""Unit tests for AccountDevice route""" + +import json +from unittest.mock import MagicMock + +import pytest +from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent + +from src.routes.auth.account_device import AccountDeviceRoute + + +@pytest.fixture +def device_manager(): + return MagicMock() + + +@pytest.fixture +def route(device_manager): + return AccountDeviceRoute(device_manager=device_manager, middlewares=[]) + + +class TestAccountDevice: + def test_create_device_returns_200(self, route, device_manager): + device_manager.upsert_device.return_value = { + "id": "dev1", + "name": "My Firefox", + "type": "desktop", + "sessionTokenId": "token123", + "createdAt": 1000, + "lastAccessTime": 1000, + } + event = APIGatewayProxyEvent( + { + "httpMethod": "POST", + "path": "/v1/account/device", + "headers": {"authorization": 'Hawk id="token123"'}, + "body": json.dumps({"name": "My Firefox", "type": "desktop"}), + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert body["id"] == "dev1" + assert body["name"] == "My Firefox" + device_manager.upsert_device.assert_called_once_with( + "uid1", + "token123", + {"name": "My Firefox", "type": "desktop"}, + ) + + def test_update_device_returns_200(self, route, device_manager): + device_manager.upsert_device.return_value = { + "id": "existing-dev", + "name": "Updated Name", + "type": "mobile", + "sessionTokenId": "token123", + "createdAt": 500, + "lastAccessTime": 2000, + } + event = APIGatewayProxyEvent( + { + "httpMethod": "POST", + "path": "/v1/account/device", + "headers": {}, + "body": json.dumps( + {"id": "existing-dev", "name": "Updated Name", "type": "mobile"} + ), + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert body["id"] == "existing-dev" + assert body["name"] == "Updated Name" + device_manager.upsert_device.assert_called_once_with( + "uid1", + "token123", + {"id": "existing-dev", "name": "Updated Name", "type": "mobile"}, + ) + + def test_missing_body_returns_200(self, route, device_manager): + device_manager.upsert_device.return_value = { + "id": "auto-id", + "name": "", + "type": "desktop", + "sessionTokenId": "token123", + "createdAt": 3000, + "lastAccessTime": 3000, + } + event = APIGatewayProxyEvent( + { + "httpMethod": "POST", + "path": "/v1/account/device", + "headers": {}, + "body": None, + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert body["id"] == "auto-id" + device_manager.upsert_device.assert_called_once_with("uid1", "token123", {}) + + +class TestAccountDeviceBind: + def test_bind_registers_post_route(self, route): + mock_api = MagicMock() + mock_api.post = MagicMock(return_value=lambda f: f) + route.bind(mock_api) + mock_api.post.assert_called_once_with("/v1/account/device", middlewares=[]) diff --git a/lambda/tests/routes/auth/test_account_devices.py b/lambda/tests/routes/auth/test_account_devices.py new file mode 100644 index 00000000..da56497f --- /dev/null +++ b/lambda/tests/routes/auth/test_account_devices.py @@ -0,0 +1,99 @@ +"""Unit tests for AccountDevices route""" + +import json +from unittest.mock import MagicMock + +import pytest +from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent + +from src.routes.auth.account_devices import AccountDevicesRoute + + +@pytest.fixture +def device_manager(): + return MagicMock() + + +@pytest.fixture +def route(device_manager): + return AccountDevicesRoute(device_manager=device_manager, middlewares=[]) + + +class TestAccountDevices: + def test_returns_device_list(self, route, device_manager): + device_manager.get_devices.return_value = [ + { + "id": "dev1", + "name": "Desktop", + "type": "desktop", + "sessionTokenId": "token123", + "createdAt": 1000, + "lastAccessTime": 2000, + }, + { + "id": "dev2", + "name": "Mobile", + "type": "mobile", + "sessionTokenId": "other-token", + "createdAt": 1500, + "lastAccessTime": 2500, + }, + ] + event = APIGatewayProxyEvent( + { + "httpMethod": "GET", + "path": "/v1/account/devices", + "headers": {"authorization": 'Hawk id="token123"'}, + "queryStringParameters": {}, + "body": None, + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert len(body) == 2 + assert body[0]["isCurrentDevice"] is True + assert body[1]["isCurrentDevice"] is False + device_manager.get_devices.assert_called_once_with("uid1", None) + + def test_filters_idle_devices(self, route, device_manager): + device_manager.get_devices.return_value = [] + event = APIGatewayProxyEvent( + { + "httpMethod": "GET", + "path": "/v1/account/devices", + "headers": {"authorization": 'Hawk id="token123"'}, + "queryStringParameters": {"filterIdleDevicesTimestamp": "1609459200000"}, + "body": None, + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + device_manager.get_devices.assert_called_once_with("uid1", 1609459200000) + + def test_returns_empty_list(self, route, device_manager): + device_manager.get_devices.return_value = [] + event = APIGatewayProxyEvent( + { + "httpMethod": "GET", + "path": "/v1/account/devices", + "headers": {"authorization": 'Hawk id="token123"'}, + "queryStringParameters": None, + "body": None, + "requestContext": {"hawk_uid": "uid1", "hawk_token_id": "token123"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert body == [] + + +class TestAccountDevicesBind: + def test_bind_registers_get_route(self, route): + mock_api = MagicMock() + mock_api.get = MagicMock(return_value=lambda f: f) + route.bind(mock_api) + mock_api.get.assert_called_once_with("/v1/account/devices", middlewares=[]) diff --git a/lambda/tests/routes/auth/test_account_devices_notify.py b/lambda/tests/routes/auth/test_account_devices_notify.py new file mode 100644 index 00000000..9182c9fd --- /dev/null +++ b/lambda/tests/routes/auth/test_account_devices_notify.py @@ -0,0 +1,39 @@ +"""Unit tests for AccountDevicesNotify route""" + +import json +from unittest.mock import MagicMock + +import pytest +from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent + +from src.routes.auth.account_devices_notify import AccountDevicesNotifyRoute + + +@pytest.fixture +def route(): + return AccountDevicesNotifyRoute(middlewares=[]) + + +class TestAccountDevicesNotify: + def test_returns_empty_object(self, route): + event = APIGatewayProxyEvent( + { + "httpMethod": "POST", + "path": "/v1/account/devices/notify", + "headers": {"authorization": 'Hawk id="token123"'}, + "body": json.dumps({"to": "all", "payload": {}}), + "requestContext": {"hawk_uid": "uid1"}, + } + ) + response = route.handle(event) + assert response.status_code == 200 + body = json.loads(response.body) + assert body == {} + + +class TestAccountDevicesNotifyBind: + def test_bind_registers_post_route(self, route): + mock_api = MagicMock() + mock_api.post = MagicMock(return_value=lambda f: f) + route.bind(mock_api) + mock_api.post.assert_called_once_with("/v1/account/devices/notify", middlewares=[]) diff --git a/lambda/tests/routes/auth/test_route_dispatch.py b/lambda/tests/routes/auth/test_route_dispatch.py index 7b454625..fa33bdf4 100644 --- a/lambda/tests/routes/auth/test_route_dispatch.py +++ b/lambda/tests/routes/auth/test_route_dispatch.py @@ -5,7 +5,11 @@ from unittest.mock import MagicMock +from src.routes.auth.account_attached_clients import AccountAttachedClientsRoute from src.routes.auth.account_create import AccountCreateRoute +from src.routes.auth.account_device import AccountDeviceRoute +from src.routes.auth.account_devices import AccountDevicesRoute +from src.routes.auth.account_devices_notify import AccountDevicesNotifyRoute from src.routes.auth.account_keys import AccountKeysRoute from src.routes.auth.account_login import AccountLoginRoute from src.routes.auth.account_status import AccountStatusRoute @@ -25,6 +29,7 @@ def _make_event(method, path, headers=None, body=None, qs=None, hawk_uid=None): ctx = {"requestId": "test"} if hawk_uid: ctx["hawk_uid"] = hawk_uid + ctx["hawk_token_id"] = "test-token-id" return { "httpMethod": method, "path": path, @@ -171,3 +176,41 @@ def test_oidc_code_exchange_dispatches(self): _make_event("POST", "/v1/oidc/exchange", body="{}"), _make_context() ) assert result["statusCode"] == 400 + + def test_account_device_dispatches(self): + mgr = MagicMock() + mgr.upsert_device.return_value = {"id": "dev1", "name": "Test"} + route = AccountDeviceRoute(device_manager=mgr, middlewares=[]) + result = _router(route).handler( + _make_event("POST", "/v1/account/device", body="{}", hawk_uid="uid1"), + _make_context(), + ) + assert result["statusCode"] == 200 + + def test_account_devices_dispatches(self): + mgr = MagicMock() + mgr.get_devices.return_value = [] + route = AccountDevicesRoute(device_manager=mgr, middlewares=[]) + result = _router(route).handler( + _make_event("GET", "/v1/account/devices", hawk_uid="uid1"), + _make_context(), + ) + assert result["statusCode"] == 200 + + def test_account_attached_clients_dispatches(self): + mgr = MagicMock() + mgr.get_devices.return_value = [] + route = AccountAttachedClientsRoute(device_manager=mgr, middlewares=[]) + result = _router(route).handler( + _make_event("GET", "/v1/account/attached_clients", hawk_uid="uid1"), + _make_context(), + ) + assert result["statusCode"] == 200 + + def test_account_devices_notify_dispatches(self): + route = AccountDevicesNotifyRoute(middlewares=[]) + result = _router(route).handler( + _make_event("POST", "/v1/account/devices/notify", body="{}", hawk_uid="uid1"), + _make_context(), + ) + assert result["statusCode"] == 200 diff --git a/lambda/tests/services/test_device_manager.py b/lambda/tests/services/test_device_manager.py new file mode 100644 index 00000000..5a81b285 --- /dev/null +++ b/lambda/tests/services/test_device_manager.py @@ -0,0 +1,273 @@ +"""Unit tests for DeviceManager with DynamoDB stubber""" + +from unittest.mock import ANY, patch + +import pytest + +from src.services.device_manager import DeviceManager + + +class TestDeviceManager: + """Test DeviceManager DynamoDB operations""" + + @pytest.fixture + def manager(self, dynamodb_table): + """Create DeviceManager instance with stubbed table""" + return DeviceManager(table=dynamodb_table) + + @pytest.fixture + def sample_uid(self): + return "abcdef1234567890abcdef1234567890" + + @pytest.fixture + def sample_session_token_id(self): + return "session-token-id-abc123" + + @pytest.fixture + def mock_time(self): + """Mock time.time() for device_manager""" + with patch("src.services.device_manager.time") as mock: + mock.time.return_value = 1000000.0 + yield mock + + @pytest.fixture + def mock_uuid(self): + """Mock uuid.uuid4() for device_manager""" + with patch("src.services.device_manager.uuid") as mock: + mock_uuid4 = mock.uuid4.return_value + mock_uuid4.hex = "aabbccdd11223344aabbccdd11223344" + yield mock + + # -- upsert_device (create) ------------------------------------------------ + + def test_upsert_device_creates_new( + self, + manager, + dynamodb_stubber, + storage_table_name, + sample_uid, + sample_session_token_id, + mock_time, + mock_uuid, + ): + """upsert_device without id generates UUID and stores new device""" + generated_id = "aabbccdd11223344aabbccdd11223344" + + # Stub put_item for the new device + dynamodb_stubber.add_response( + "put_item", + {}, + { + "TableName": storage_table_name, + "Item": { + "PK": f"DEVICE#{sample_uid}#{generated_id}", + "id": generated_id, + "name": "My Phone", + "type": "mobile", + "pushCallback": None, + "pushPublicKey": None, + "pushAuthKey": None, + "pushEndpointExpired": False, + "availableCommands": {}, + "sessionTokenId": sample_session_token_id, + "createdAt": 1000000000, + "lastAccessTime": 1000000000, + }, + }, + ) + + result = manager.upsert_device( + uid=sample_uid, + session_token_id=sample_session_token_id, + data={"name": "My Phone", "type": "mobile"}, + ) + + assert result["id"] == generated_id + assert result["name"] == "My Phone" + assert result["type"] == "mobile" + assert result["createdAt"] == 1000000000 + assert result["lastAccessTime"] == 1000000000 + assert result["sessionTokenId"] == sample_session_token_id + dynamodb_stubber.assert_no_pending_responses() + + # -- upsert_device (update) ------------------------------------------------ + + def test_upsert_device_updates_existing( + self, + manager, + dynamodb_stubber, + storage_table_name, + sample_uid, + sample_session_token_id, + mock_time, + ): + """upsert_device with id merges fields into existing device""" + device_id = "existing-device-id-00000000000000" + + # Stub get_item for the existing device + dynamodb_stubber.add_response( + "get_item", + { + "Item": { + "PK": {"S": f"DEVICE#{sample_uid}#{device_id}"}, + "id": {"S": device_id}, + "name": {"S": "Old Name"}, + "type": {"S": "desktop"}, + "createdAt": {"N": "999000000"}, + "lastAccessTime": {"N": "999000000"}, + "sessionTokenId": {"S": "old-session-token"}, + }, + }, + { + "TableName": storage_table_name, + "Key": {"PK": f"DEVICE#{sample_uid}#{device_id}"}, + }, + ) + + # Stub put_item for the updated device + dynamodb_stubber.add_response( + "put_item", + {}, + { + "TableName": storage_table_name, + "Item": { + "PK": f"DEVICE#{sample_uid}#{device_id}", + "id": device_id, + "name": "New Name", + "type": "desktop", + "createdAt": 999000000, + "lastAccessTime": 1000000000, + "sessionTokenId": sample_session_token_id, + }, + }, + ) + + result = manager.upsert_device( + uid=sample_uid, + session_token_id=sample_session_token_id, + data={"id": device_id, "name": "New Name"}, + ) + + assert result["id"] == device_id + assert result["name"] == "New Name" + assert result["type"] == "desktop" + assert result["createdAt"] == 999000000 + assert result["lastAccessTime"] == 1000000000 + assert result["sessionTokenId"] == sample_session_token_id + dynamodb_stubber.assert_no_pending_responses() + + # -- get_devices ----------------------------------------------------------- + + def test_get_devices_returns_all( + self, + manager, + dynamodb_stubber, + storage_table_name, + sample_uid, + ): + """get_devices returns all devices for a user""" + device_id_1 = "device-1-00000000000000000000" + device_id_2 = "device-2-00000000000000000000" + + # Stub scan + dynamodb_stubber.add_response( + "scan", + { + "Items": [ + { + "PK": {"S": f"DEVICE#{sample_uid}#{device_id_1}"}, + "id": {"S": device_id_1}, + "name": {"S": "Phone"}, + "lastAccessTime": {"N": "2000000000"}, + }, + { + "PK": {"S": f"DEVICE#{sample_uid}#{device_id_2}"}, + "id": {"S": device_id_2}, + "name": {"S": "Laptop"}, + "lastAccessTime": {"N": "2000000000"}, + }, + ], + }, + { + "TableName": storage_table_name, + "FilterExpression": ANY, + }, + ) + + devices = manager.get_devices(uid=sample_uid) + + assert len(devices) == 2 + assert devices[0]["id"] == device_id_1 + assert devices[0]["name"] == "Phone" + assert devices[1]["id"] == device_id_2 + assert devices[1]["name"] == "Laptop" + # PK should be stripped + assert "PK" not in devices[0] + assert "PK" not in devices[1] + dynamodb_stubber.assert_no_pending_responses() + + def test_get_devices_filters_idle( + self, + manager, + dynamodb_stubber, + storage_table_name, + sample_uid, + ): + """get_devices excludes devices with lastAccessTime below threshold""" + device_id_active = "device-active-0000000000000000" + device_id_idle = "device-idle-00000000000000000" + + # Stub scan returning both active and idle devices + dynamodb_stubber.add_response( + "scan", + { + "Items": [ + { + "PK": {"S": f"DEVICE#{sample_uid}#{device_id_active}"}, + "id": {"S": device_id_active}, + "name": {"S": "Active Phone"}, + "lastAccessTime": {"N": "2000000000"}, + }, + { + "PK": {"S": f"DEVICE#{sample_uid}#{device_id_idle}"}, + "id": {"S": device_id_idle}, + "name": {"S": "Idle Laptop"}, + "lastAccessTime": {"N": "1000000000"}, + }, + ], + }, + { + "TableName": storage_table_name, + "FilterExpression": ANY, + }, + ) + + devices = manager.get_devices(uid=sample_uid, filter_idle_timestamp=1500000000) + + assert len(devices) == 1 + assert devices[0]["id"] == device_id_active + assert devices[0]["name"] == "Active Phone" + dynamodb_stubber.assert_no_pending_responses() + + def test_get_devices_empty( + self, + manager, + dynamodb_stubber, + storage_table_name, + sample_uid, + ): + """get_devices returns empty list when scan returns no items""" + # Stub scan returning no items + dynamodb_stubber.add_response( + "scan", + {"Items": []}, + { + "TableName": storage_table_name, + "FilterExpression": ANY, + }, + ) + + devices = manager.get_devices(uid=sample_uid) + + assert devices == [] + dynamodb_stubber.assert_no_pending_responses() diff --git a/lambda/tests/services/test_fxa_token_manager.py b/lambda/tests/services/test_fxa_token_manager.py index 31acfc88..81276172 100644 --- a/lambda/tests/services/test_fxa_token_manager.py +++ b/lambda/tests/services/test_fxa_token_manager.py @@ -1473,3 +1473,17 @@ def test_custom_keyfetch_ttl( manager.create_key_fetch_token("test-uid") dynamodb_stubber.assert_no_pending_responses() + + +class TestExtractTokenIdFromHawkHeader: + def test_valid_header(self): + result = FxATokenManager.extract_token_id_from_hawk_header( + 'Hawk id="abc123", ts="1234567890", nonce="xyz"' + ) + assert result == "abc123" + + def test_empty_header(self): + assert FxATokenManager.extract_token_id_from_hawk_header("") is None + + def test_no_match(self): + assert FxATokenManager.extract_token_id_from_hawk_header("Bearer token") is None diff --git a/lambda/tests/services/test_storage_hawk_middleware.py b/lambda/tests/services/test_storage_hawk_middleware.py index 0d4ee339..eb17b3d7 100644 --- a/lambda/tests/services/test_storage_hawk_middleware.py +++ b/lambda/tests/services/test_storage_hawk_middleware.py @@ -251,6 +251,7 @@ def test_session_hawk_success_injects_hawk_uid(self): result = middleware.handler(app, mock_next) assert raw_event["requestContext"]["hawk_uid"] == "uid123" + assert raw_event["requestContext"]["hawk_token_id"] == "tokenid" mock_next.assert_called_once_with(app) assert result.status_code == 200