Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lambda/src/middlewares/hawk_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def _validate_storage_hawk(self, event, auth_header, method, path, host, port):
"""Validate storage Hawk token and check URL uid matches authenticated user."""
assert self._hawk_service is not None
try:
creds = self._hawk_service.validate(auth_header, method, path, host, port)
query_params = event.query_string_parameters
creds = self._hawk_service.validate(
auth_header, method, path, host, port, query_params=query_params
)
except Exception as e:
raise HawkAuthenticationError(str(e)) from e

Expand Down
94 changes: 85 additions & 9 deletions lambda/src/services/hawk_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

import base64
import binascii
import hashlib
import hmac
import re
import secrets
import time
from dataclasses import dataclass
from itertools import permutations
from typing import Optional, Tuple

import mohawk
Expand All @@ -30,6 +33,7 @@

# Pattern to extract hawk id from header without full parse
_HAWK_ID_PATTERN = re.compile(r'id="([^"]+)"')
_HAWK_FIELD_PATTERN = re.compile(r'(\w+)="([^"]*)"')


@dataclass
Expand Down Expand Up @@ -63,13 +67,23 @@ def __init__(
self.token_duration = token_duration

def validate(
self, authorization_header: str, method: str, path: str, host: str, port: int
self,
authorization_header: str,
method: str,
path: str,
host: str,
port: int,
query_params: Optional[dict] = None,
) -> HawkCredentials:
"""
Validate HAWK Authorization header and return credentials.

Uses mohawk.Receiver for header parsing, timestamp validation, and MAC verification.
Custom business logic (expiry, generation) runs in the credentials_map callback.

When query_params are provided and contain multiple keys, pre-computes
the correct query string ordering before calling mohawk (API Gateway REST
API v1 may alphabetize query parameters, breaking the Hawk MAC).
"""
# Extract hawk_id for pre-validation of custom business logic
hawk_id = self._extract_hawk_id(authorization_header)
Expand All @@ -79,6 +93,15 @@ def validate(
if not self.validate_hawk_id_expiry(expiry):
raise ExpiredHawkTokenException(f"HAWK token expired at {expiry}")

# If query params may have been reordered, find the correct order
# before calling mohawk (to avoid consuming the nonce on a wrong URL)
if query_params and len(query_params) > 1:
corrected = self._correct_query_order(
authorization_header, hawk_id, method, path, query_params, host, port
)
if corrected:
path = corrected

# Credentials lookup called by mohawk during MAC verification
def credentials_map(sender_id):
hawk_key, cached_user_id, cached_generation = self.get_hawk_key_from_cache(sender_id)
Expand Down Expand Up @@ -106,14 +129,6 @@ def credentials_map(sender_id):
except mohawk.exc.BadHeaderValue as e:
raise InvalidHawkHeaderException(str(e))
except mohawk.exc.MacMismatch:
logger.error(
"HAWK MAC mismatch — query string ordering issue?",
extra={
"method": method,
"server_url": f"https://{host}:{port}{path}",
"path": path,
},
)
raise InvalidHawkSignatureException("HAWK MAC verification failed")
except mohawk.exc.TokenExpired:
raise InvalidHawkSignatureException("Timestamp outside acceptable window")
Expand Down Expand Up @@ -154,6 +169,67 @@ def _seen_nonce(self, sender_id, nonce, timestamp):
return True # Replay detected
raise

def _correct_query_order(
    self,
    authorization_header: str,
    hawk_id: str,
    method: str,
    path: str,
    query_params: dict,
    host: str,
    port: int,
) -> Optional[str]:
    """Find the query string ordering the client used for Hawk MAC computation.

    API Gateway REST API v1 may reorder query parameters alphabetically.
    This pre-computes the Hawk MAC for each permutation of ``query_params``
    to find the client's original ordering, before mohawk consumes the nonce.

    Args:
        authorization_header: Raw ``Authorization`` header value.
        hawk_id: Hawk id used to look up the signing key.
        method: HTTP method placed in the normalized request string.
        path: Request path, optionally already carrying a query string.
        query_params: Query parameters as received (key -> value).
        host: Host placed in the normalized request string.
        port: Port placed in the normalized request string.

    Returns:
        The path with the query string whose MAC equals the client's MAC
        (this may be identical to ``path``), or None when the search is
        skipped or no permutation matches. On None the caller keeps the
        original path and mohawk reports any resulting error.
    """
    # The search is factorial in the number of parameters and query_params
    # comes from the request, so bound it: 8! = 40,320 candidates at most.
    max_params = 8
    if not query_params or len(query_params) > max_params:
        return None

    fields = self._parse_hawk_fields(authorization_header)
    if not fields or any(name not in fields for name in ("mac", "ts", "nonce")):
        return None

    try:
        hawk_key, _, _ = self.get_hawk_key_from_cache(hawk_id)
    except (AuthenticationException, ClientError):
        return None  # Let mohawk handle the error

    # Everything except the resource line is identical for every candidate,
    # so build and encode it once. UTF-8 is byte-identical to ASCII for
    # ASCII input and cannot raise on non-ASCII (URL-decoded) values.
    key_bytes = hawk_key.encode("ascii")
    head = f"hawk.1.header\n{fields['ts']}\n{fields['nonce']}\n{method}\n".encode("utf-8")
    tail = (
        f"\n{host}\n{port}\n"
        f"{fields.get('hash', '')}\n{fields.get('ext', '')}\n"
    ).encode("utf-8")
    # Compare as bytes: compare_digest() raises TypeError for non-ASCII str.
    client_mac = fields["mac"].encode("utf-8")

    base_path = path.split("?", 1)[0]
    pairs = [f"{k}={v}" for k, v in query_params.items()]

    # permutations() yields the as-received order first, so an unmodified
    # request matches on the very first candidate.
    for ordering in permutations(pairs):
        resource = f"{base_path}?{'&'.join(ordering)}"
        computed = base64.b64encode(
            hmac.digest(key_bytes, head + resource.encode("utf-8") + tail, "sha256")
        )
        if hmac.compare_digest(computed, client_mac):
            if resource != path:
                logger.info(
                    "Corrected query string order",
                    extra={"original": path, "corrected": resource},
                )
            return resource

    return None  # No permutation matched; let mohawk report the error

@staticmethod
def _parse_hawk_fields(authorization_header: str) -> dict:
    """Collect every ``name="value"`` pair found in a Hawk Authorization header.

    Returns a dict mapping each field name to its quoted value. When a
    name occurs more than once, the last occurrence wins.
    """
    parsed = {}
    for name, value in _HAWK_FIELD_PATTERN.findall(authorization_header):
        parsed[name] = value
    return parsed

def _extract_hawk_id(self, authorization_header: str) -> str:
"""Extract the id field from a Hawk Authorization header."""
if not authorization_header or not authorization_header.startswith("Hawk "):
Expand Down
12 changes: 0 additions & 12 deletions lambda/src/shared/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import json
import logging
from datetime import datetime, timezone
from decimal import Decimal

logger = logging.getLogger(__name__)


def datetime_encoder(dt: datetime) -> Decimal:
"""Convert datetime to Unix timestamp (Decimal) for DynamoDB serialization"""
Expand Down Expand Up @@ -68,17 +65,8 @@ def extract_hawk_request_params(event) -> tuple[str, str, str, int]:

query_params = event.query_string_parameters
if query_params:
param_keys = list(query_params.keys())
qs = "&".join(f"{k}={v}" for k, v in query_params.items())
path = f"{path}?{qs}"
logger.info(
"Hawk query string reconstruction",
extra={
"original_path": event.path,
"reconstructed_path": path,
"param_key_order": param_keys,
},
)

try:
host = event.request_context.domain_name or "localhost"
Expand Down
Loading
Loading