Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions lambda/src/environment/service_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import boto3
from aws_lambda_powertools.event_handler import CORSConfig, Response

from src.routes.auth.account_attached_clients import AccountAttachedClientsRoute
from src.routes.auth.account_create import AccountCreateRoute
from src.routes.auth.account_device import AccountDeviceRoute
from src.routes.auth.account_devices import AccountDevicesRoute
from src.routes.auth.account_devices_notify import AccountDevicesNotifyRoute
from src.routes.auth.account_keys import AccountKeysRoute
from src.routes.auth.account_login import AccountLoginRoute
from src.routes.auth.account_status import AccountStatusRoute
Expand Down Expand Up @@ -46,6 +50,7 @@
)
from src.services.auth_account_manager import AuthAccountManager
from src.services.channel_service import ChannelService
from src.services.device_manager import DeviceManager
from src.services.fxa_token_manager import FxATokenManager
from src.services.hawk_service import HawkService
from src.services.jwt_service import JWTService
Expand Down Expand Up @@ -268,6 +273,10 @@ def kms_client(self): # pragma: nocover
def auth_account_manager(self) -> AuthAccountManager:
return AuthAccountManager(table=self.auth_table)

@cached_property
def device_manager(self) -> DeviceManager:
return DeviceManager(table=self.auth_table)

@cached_property
def fxa_token_manager(self) -> FxATokenManager:
return FxATokenManager(table=self.auth_table)
Expand Down Expand Up @@ -351,6 +360,22 @@ def auth_api_router(self):
oidc_validator=self.oidc_validator,
account_manager=self.auth_account_manager,
),
# Device management routes
AccountDeviceRoute(
device_manager=self.device_manager,
middlewares=[self.session_hawk_middleware],
),
AccountDevicesRoute(
device_manager=self.device_manager,
middlewares=[self.session_hawk_middleware],
),
AccountAttachedClientsRoute(
device_manager=self.device_manager,
middlewares=[self.session_hawk_middleware],
),
AccountDevicesNotifyRoute(
middlewares=[self.session_hawk_middleware],
),
],
middlewares=[WeaveTimestampMiddleware()],
cors=self.cors_config,
Expand Down
6 changes: 5 additions & 1 deletion lambda/src/middlewares/hawk_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
authentication. Pass hawk_service for storage API, token_manager for auth API
session-authenticated routes.

On success, injects ``hawk_uid`` into event["requestContext"].
On success, injects ``hawk_uid`` and ``hawk_token_id`` into event["requestContext"].
On failure, raises HawkAuthenticationError (handle with router exception handler).
"""

Expand Down Expand Up @@ -85,3 +85,7 @@ def _validate_session_hawk(self, event, auth_header, method, path, host, port):
raise HawkAuthenticationError("Invalid or expired session token")

event["requestContext"]["hawk_uid"] = uid
# Inject the Hawk id (session token ID) for device correlation
event["requestContext"]["hawk_token_id"] = (
FxATokenManager.extract_token_id_from_hawk_header(auth_header) or ""
)
58 changes: 58 additions & 0 deletions lambda/src/routes/auth/account_attached_clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""AccountAttachedClients route — GET /v1/account/attached_clients"""

import json
from typing import Sequence

from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler

from src.services.device_manager import DeviceManager
from src.shared.base_route import BaseRoute


class AccountAttachedClientsRoute(BaseRoute):
"""Return attached clients derived from device records."""

def __init__(
self,
device_manager: DeviceManager,
middlewares: Sequence[BaseMiddlewareHandler] = (),
):
self._device_manager = device_manager
self.middlewares = middlewares

def bind(self, app: APIGatewayRestResolver):
@app.get("/v1/account/attached_clients", middlewares=list(self.middlewares))
def handle_account_attached_clients():
return self.handle(app.current_event)

def handle(self, event) -> Response:
uid = event["requestContext"]["hawk_uid"]
session_token_id = event["requestContext"].get("hawk_token_id", "")

devices = self._device_manager.get_devices(uid)
clients = []
for d in devices:
clients.append(
{
"clientId": None,
"deviceId": d.get("id"),
"sessionTokenId": d.get("sessionTokenId"),
"refreshTokenId": None,
"isCurrentSession": d.get("sessionTokenId") == session_token_id,
"deviceType": d.get("type"),
"name": d.get("name"),
"createdTime": d.get("createdAt"),
"lastAccessTime": d.get("lastAccessTime"),
"scope": None,
"location": {},
"userAgent": "",
"os": None,
}
)

return Response(
status_code=200,
content_type="application/json",
body=json.dumps(clients),
)
39 changes: 39 additions & 0 deletions lambda/src/routes/auth/account_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""AccountDevice route — POST /v1/account/device"""

import json
from typing import Sequence

from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler

from src.services.device_manager import DeviceManager
from src.shared.base_route import BaseRoute


class AccountDeviceRoute(BaseRoute):
"""Create or update a device registration."""

def __init__(
self,
device_manager: DeviceManager,
middlewares: Sequence[BaseMiddlewareHandler] = (),
):
self._device_manager = device_manager
self.middlewares = middlewares

def bind(self, app: APIGatewayRestResolver):
@app.post("/v1/account/device", middlewares=list(self.middlewares))
def handle_account_device():
return self.handle(app.current_event)

def handle(self, event) -> Response:
uid = event["requestContext"]["hawk_uid"]
session_token_id = event["requestContext"].get("hawk_token_id", "")
body = json.loads(event.body or "{}")
device = self._device_manager.upsert_device(uid, session_token_id, body)

return Response(
status_code=200,
content_type="application/json",
body=json.dumps(device),
)
45 changes: 45 additions & 0 deletions lambda/src/routes/auth/account_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""AccountDevices route — GET /v1/account/devices"""

import json
from typing import Sequence

from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler

from src.services.device_manager import DeviceManager
from src.shared.base_route import BaseRoute


class AccountDevicesRoute(BaseRoute):
"""List all devices for the authenticated user."""

def __init__(
self,
device_manager: DeviceManager,
middlewares: Sequence[BaseMiddlewareHandler] = (),
):
self._device_manager = device_manager
self.middlewares = middlewares

def bind(self, app: APIGatewayRestResolver):
@app.get("/v1/account/devices", middlewares=list(self.middlewares))
def handle_account_devices():
return self.handle(app.current_event)

def handle(self, event) -> Response:
uid = event["requestContext"]["hawk_uid"]
session_token_id = event["requestContext"].get("hawk_token_id", "")

params = event.query_string_parameters or {}
filter_ts = params.get("filterIdleDevicesTimestamp")
filter_idle = int(filter_ts) if filter_ts else None

devices = self._device_manager.get_devices(uid, filter_idle)
for d in devices:
d["isCurrentDevice"] = d.get("sessionTokenId") == session_token_id

return Response(
status_code=200,
content_type="application/json",
body=json.dumps(devices),
)
28 changes: 28 additions & 0 deletions lambda/src/routes/auth/account_devices_notify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""AccountDevicesNotify route — POST /v1/account/devices/notify (no-op)"""

import json
from typing import Sequence

from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler

from src.shared.base_route import BaseRoute


class AccountDevicesNotifyRoute(BaseRoute):
"""No-op push notification endpoint for device-to-device messaging."""

def __init__(self, middlewares: Sequence[BaseMiddlewareHandler] = ()):
self.middlewares = middlewares

def bind(self, app: APIGatewayRestResolver):
@app.post("/v1/account/devices/notify", middlewares=list(self.middlewares))
def handle_account_devices_notify():
return self.handle(app.current_event)

def handle(self, event) -> Response:
return Response(
status_code=200,
content_type="application/json",
body=json.dumps({}),
)
68 changes: 68 additions & 0 deletions lambda/src/services/device_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Device management for FxA device registration and listing."""

import re
import time
import uuid
from typing import Optional

from boto3.dynamodb.conditions import Attr

DEVICE_PREFIX = "DEVICE"
_HAWK_ID_PATTERN = re.compile(r'id="([^"]+)"')


class DeviceManager:
"""Manages FxA device records in DynamoDB."""

def __init__(self, table):
self.table = table

def _device_pk(self, uid: str, device_id: str) -> str:
return f"{DEVICE_PREFIX}#{uid}#{device_id}"

def upsert_device(self, uid: str, session_token_id: str, data: dict) -> dict:
"""Create or update a device record."""
now = int(time.time() * 1000)
device_id = data.get("id")

if device_id:
# Update: get existing, merge new fields
response = self.table.get_item(Key={"PK": self._device_pk(uid, device_id)})
existing = response.get("Item", {})
existing.pop("PK", None)
# Merge: new data overwrites existing, but preserve createdAt
device = {**existing, **{k: v for k, v in data.items() if v is not None}}
device["lastAccessTime"] = now
device["sessionTokenId"] = session_token_id
else:
# Create new device
device_id = uuid.uuid4().hex
device = {
"id": device_id,
"name": data.get("name", ""),
"type": data.get("type", "desktop"),
"pushCallback": data.get("pushCallback"),
"pushPublicKey": data.get("pushPublicKey"),
"pushAuthKey": data.get("pushAuthKey"),
"pushEndpointExpired": False,
"availableCommands": data.get("availableCommands", {}),
"sessionTokenId": session_token_id,
"createdAt": now,
"lastAccessTime": now,
}

self.table.put_item(Item={"PK": self._device_pk(uid, device_id), **device})
return device

def get_devices(self, uid: str, filter_idle_timestamp: Optional[int] = None) -> list[dict]:
"""List all devices for a user."""
response = self.table.scan(
FilterExpression=Attr("PK").begins_with(f"{DEVICE_PREFIX}#{uid}#")
)
devices = []
for item in response.get("Items", []):
item.pop("PK", None)
if filter_idle_timestamp and item.get("lastAccessTime", 0) < filter_idle_timestamp:
continue
devices.append(item)
return devices
11 changes: 11 additions & 0 deletions lambda/src/services/fxa_token_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""FxA Token Manager for session tokens and key-fetch tokens in DynamoDB"""

import re
import time
from typing import Optional

Expand All @@ -19,10 +20,20 @@
SESSION_PREFIX = "SESSION"
KEYFETCH_PREFIX = "KEYFETCH"

_HAWK_ID_PATTERN = re.compile(r'id="([^"]+)"')


class FxATokenManager:
"""Manages FxA session tokens and key-fetch tokens in DynamoDB"""

@staticmethod
def extract_token_id_from_hawk_header(authorization_header: str) -> str | None:
"""Extract the Hawk id field (session token ID) from an Authorization header."""
if not authorization_header:
return None
match = _HAWK_ID_PATTERN.search(authorization_header)
return match.group(1) if match else None

def __init__(
self,
table,
Expand Down
Loading
Loading