diff --git a/lambda/src/middlewares/hawk_auth.py b/lambda/src/middlewares/hawk_auth.py index 43fae705..13d3cfc5 100644 --- a/lambda/src/middlewares/hawk_auth.py +++ b/lambda/src/middlewares/hawk_auth.py @@ -60,7 +60,10 @@ def _validate_storage_hawk(self, event, auth_header, method, path, host, port): """Validate storage Hawk token and check URL uid matches authenticated user.""" assert self._hawk_service is not None try: - creds = self._hawk_service.validate(auth_header, method, path, host, port) + query_params = event.query_string_parameters + creds = self._hawk_service.validate( + auth_header, method, path, host, port, query_params=query_params + ) except Exception as e: raise HawkAuthenticationError(str(e)) from e diff --git a/lambda/src/services/hawk_service.py b/lambda/src/services/hawk_service.py index a3bf4252..94d7e3a6 100644 --- a/lambda/src/services/hawk_service.py +++ b/lambda/src/services/hawk_service.py @@ -7,10 +7,13 @@ import base64 import binascii +import hashlib +import hmac import re import secrets import time from dataclasses import dataclass +from itertools import permutations from typing import Optional, Tuple import mohawk @@ -30,6 +33,7 @@ # Pattern to extract hawk id from header without full parse _HAWK_ID_PATTERN = re.compile(r'id="([^"]+)"') +_HAWK_FIELD_PATTERN = re.compile(r'(\w+)="([^"]*)"') @dataclass @@ -63,13 +67,23 @@ def __init__( self.token_duration = token_duration def validate( - self, authorization_header: str, method: str, path: str, host: str, port: int + self, + authorization_header: str, + method: str, + path: str, + host: str, + port: int, + query_params: Optional[dict] = None, ) -> HawkCredentials: """ Validate HAWK Authorization header and return credentials. Uses mohawk.Receiver for header parsing, timestamp validation, and MAC verification. Custom business logic (expiry, generation) runs in the credentials_map callback. 
+ + When query_params are provided and contain multiple keys, pre-computes + the correct query string ordering before calling mohawk (API Gateway REST + API v1 may alphabetize query parameters, breaking the Hawk MAC). """ # Extract hawk_id for pre-validation of custom business logic hawk_id = self._extract_hawk_id(authorization_header) @@ -79,6 +93,15 @@ def validate( if not self.validate_hawk_id_expiry(expiry): raise ExpiredHawkTokenException(f"HAWK token expired at {expiry}") + # If query params may have been reordered, find the correct order + # before calling mohawk (to avoid consuming the nonce on a wrong URL) + if query_params and len(query_params) > 1: + corrected = self._correct_query_order( + authorization_header, hawk_id, method, path, query_params, host, port + ) + if corrected: + path = corrected + # Credentials lookup called by mohawk during MAC verification def credentials_map(sender_id): hawk_key, cached_user_id, cached_generation = self.get_hawk_key_from_cache(sender_id) @@ -106,14 +129,6 @@ def credentials_map(sender_id): except mohawk.exc.BadHeaderValue as e: raise InvalidHawkHeaderException(str(e)) except mohawk.exc.MacMismatch: - logger.error( - "HAWK MAC mismatch — query string ordering issue?", - extra={ - "method": method, - "server_url": f"https://{host}:{port}{path}", - "path": path, - }, - ) raise InvalidHawkSignatureException("HAWK MAC verification failed") except mohawk.exc.TokenExpired: raise InvalidHawkSignatureException("Timestamp outside acceptable window") @@ -154,6 +169,67 @@ def _seen_nonce(self, sender_id, nonce, timestamp): return True # Replay detected raise + def _correct_query_order( + self, + authorization_header: str, + hawk_id: str, + method: str, + path: str, + query_params: dict, + host: str, + port: int, + ) -> Optional[str]: + """Find the query string ordering the client used for Hawk MAC computation. + + API Gateway REST API v1 may reorder query parameters alphabetically. 
+ This pre-computes the Hawk MAC for each permutation to find the + client's original ordering, before mohawk consumes the nonce. + + Returns the path (with query string) whose MAC matches, even when it is the + current ordering; None if nothing matches or the header/key is unavailable. + """ + fields = self._parse_hawk_fields(authorization_header) + if not fields or "mac" not in fields or "ts" not in fields or "nonce" not in fields: + return None + + try: + hawk_key, _, _ = self.get_hawk_key_from_cache(hawk_id) + except (AuthenticationException, ClientError): + return None # Let mohawk handle the error + + client_mac = fields["mac"] + base_path = path.split("?")[0] + items = list(query_params.items()) + + for perm in permutations(items): + qs = "&".join(f"{k}={v}" for k, v in perm) + resource = f"{base_path}?{qs}" + normalized = ( + f"hawk.1.header\n{fields['ts']}\n{fields['nonce']}\n" + f"{method}\n{resource}\n{host}\n{port}\n" + f"{fields.get('hash', '')}\n{fields.get('ext', '')}\n" + ) + computed = base64.b64encode( + hmac.new( + hawk_key.encode("ascii"), normalized.encode("ascii"), hashlib.sha256 + ).digest() + ).decode("ascii") + + if hmac.compare_digest(computed, client_mac): + if resource != path: + logger.info( + "Corrected query string order", + extra={"original": path, "corrected": resource}, + ) + return resource + + return None # No permutation matched; let mohawk report the error + + @staticmethod + def _parse_hawk_fields(authorization_header: str) -> dict: + """Parse all key=value fields from a Hawk Authorization header.""" + return dict(_HAWK_FIELD_PATTERN.findall(authorization_header)) + def _extract_hawk_id(self, authorization_header: str) -> str: """Extract the id field from a Hawk Authorization header.""" if not authorization_header or not authorization_header.startswith("Hawk "): diff --git a/lambda/src/shared/utils.py b/lambda/src/shared/utils.py index e4ed11d9..cc605e80 100644 --- a/lambda/src/shared/utils.py +++ b/lambda/src/shared/utils.py @@ -1,10 +1,7 @@ import json 
-import logging from datetime import datetime, timezone from decimal import Decimal -logger = logging.getLogger(__name__) - def datetime_encoder(dt: datetime) -> Decimal: """Convert datetime to Unix timestamp (Decimal) for DynamoDB serialization""" @@ -68,17 +65,8 @@ def extract_hawk_request_params(event) -> tuple[str, str, str, int]: query_params = event.query_string_parameters if query_params: - param_keys = list(query_params.keys()) qs = "&".join(f"{k}={v}" for k, v in query_params.items()) path = f"{path}?{qs}" - logger.info( - "Hawk query string reconstruction", - extra={ - "original_path": event.path, - "reconstructed_path": path, - "param_key_order": param_keys, - }, - ) try: host = event.request_context.domain_name or "localhost" diff --git a/lambda/tests/services/test_hawk_service.py b/lambda/tests/services/test_hawk_service.py index d22212e6..eb7ad38a 100644 --- a/lambda/tests/services/test_hawk_service.py +++ b/lambda/tests/services/test_hawk_service.py @@ -534,6 +534,112 @@ def test_validate_rejects_replayed_nonce(self, hawk_service, mock_dynamodb_table hawk_service.validate(header, "GET", "/test", "host", 443) +class TestValidateQueryParamCorrection: + """Tests for query parameter reordering correction in validate()""" + + def test_validate_corrects_reordered_query_params(self, hawk_service, mock_dynamodb_table): + """validate() succeeds when API Gateway alphabetizes query params.""" + user_id = "user123" + generation = 5 + expiry = int(time.time()) + 300 + hawk_id = ( + base64.urlsafe_b64encode(f"{user_id}:{generation}:{expiry}".encode()) + .decode() + .rstrip("=") + ) + hawk_key = "a" * 64 + method = "GET" + host = "api.example.com" + port = 443 + + # Client sends: newer=1.09&full=1&limit=1000 + client_path = "/storage/prefs?newer=1.09&full=1&limit=1000" + authorization_header = build_hawk_auth_header( + hawk_id, hawk_key, method, client_path, host, port + ) + + mock_dynamodb_table.get_item.return_value = { + "Item": {"hawk_key": hawk_key, "user_id": 
user_id, "generation": generation} + } + mock_dynamodb_table.put_item.return_value = {} + + # Server receives alphabetized path + server_path = "/storage/prefs?full=1&limit=1000&newer=1.09" + query_params = {"full": "1", "limit": "1000", "newer": "1.09"} + + creds = hawk_service.validate( + authorization_header, method, server_path, host, port, query_params=query_params + ) + assert creds.user_id == user_id + + def test_validate_correction_returns_none_falls_through( + self, hawk_service, mock_dynamodb_table + ): + """When no permutation matches, validate proceeds with original path (and fails).""" + user_id = "user1" + generation = 1 + expiry = int(time.time()) + 300 + hawk_id = ( + base64.urlsafe_b64encode(f"{user_id}:{generation}:{expiry}".encode()) + .decode() + .rstrip("=") + ) + hawk_key = "c" * 64 + method = "GET" + host = "api.example.com" + port = 443 + + # Build header with a path that won't match any permutation of the query params + # (client used a completely different path) + client_path = "/storage/other?x=1&y=2" + authorization_header = build_hawk_auth_header( + hawk_id, hawk_key, method, client_path, host, port + ) + + mock_dynamodb_table.get_item.return_value = { + "Item": {"hawk_key": hawk_key, "user_id": user_id, "generation": generation} + } + mock_dynamodb_table.put_item.return_value = {} + + # Server has different params — no permutation will match + server_path = "/storage/prefs?a=1&b=2" + query_params = {"a": "1", "b": "2"} + + with pytest.raises(InvalidHawkSignatureException): + hawk_service.validate( + authorization_header, method, server_path, host, port, query_params=query_params + ) + + def test_validate_no_correction_needed_for_single_param( + self, hawk_service, mock_dynamodb_table + ): + """Single query param doesn't trigger permutation logic.""" + user_id = "user1" + generation = 1 + expiry = int(time.time()) + 300 + hawk_id = ( + base64.urlsafe_b64encode(f"{user_id}:{generation}:{expiry}".encode()) + .decode() + .rstrip("=") + ) + 
hawk_key = "b" * 64 + method = "GET" + host = "api.example.com" + port = 443 + path = "/storage/tabs?full=1" + + authorization_header = build_hawk_auth_header(hawk_id, hawk_key, method, path, host, port) + mock_dynamodb_table.get_item.return_value = { + "Item": {"hawk_key": hawk_key, "user_id": user_id, "generation": generation} + } + mock_dynamodb_table.put_item.return_value = {} + + creds = hawk_service.validate( + authorization_header, method, path, host, port, query_params={"full": "1"} + ) + assert creds.user_id == user_id + + class TestSeenNonce: """Tests for _seen_nonce method""" @@ -701,3 +807,125 @@ def test_hawk_credentials_optional_key(self): ) assert credentials.hawk_key is None + + +class TestParseHawkFields: + """Tests for _parse_hawk_fields""" + + def test_parses_all_fields(self, hawk_service): + header = 'Hawk id="abc", ts="123", nonce="xyz", mac="sig=", hash="h", ext="e"' + fields = hawk_service._parse_hawk_fields(header) + assert fields["id"] == "abc" + assert fields["ts"] == "123" + assert fields["nonce"] == "xyz" + assert fields["mac"] == "sig=" + assert fields["hash"] == "h" + assert fields["ext"] == "e" + + def test_empty_header(self, hawk_service): + assert hawk_service._parse_hawk_fields("") == {} + + +class TestCorrectQueryOrder: + """Tests for _correct_query_order — pre-computes MAC to find client's param ordering""" + + def test_finds_correct_order(self, hawk_service, mock_dynamodb_table): + """When API Gateway reorders params, finds the client's original order.""" + import hashlib + import hmac as hmac_mod + + hawk_key = "a" * 64 + hawk_id = base64.urlsafe_b64encode(b"user1:1:9999999999").decode().rstrip("=") + mock_dynamodb_table.get_item.return_value = { + "Item": {"hawk_key": hawk_key, "user_id": "user1", "generation": 1} + } + + ts = "1234567890" + nonce = "abc123" + method = "GET" + host = "storage.example.com" + port = 443 + # Client used: newer=1.09&full=1&limit=1000 + client_resource = 
"/storage/prefs?newer=1.09&full=1&limit=1000" + normalized = ( + f"hawk.1.header\n{ts}\n{nonce}\n{method}\n" f"{client_resource}\n{host}\n{port}\n\n\n" + ) + client_mac = base64.b64encode( + hmac_mod.new( + hawk_key.encode("ascii"), normalized.encode("ascii"), hashlib.sha256 + ).digest() + ).decode("ascii") + + auth_header = f'Hawk id="{hawk_id}", ts="{ts}", nonce="{nonce}", mac="{client_mac}"' + + # Server has alphabetized path + server_path = "/storage/prefs?full=1&limit=1000&newer=1.09" + query_params = {"full": "1", "limit": "1000", "newer": "1.09"} + + result = hawk_service._correct_query_order( + auth_header, hawk_id, method, server_path, query_params, host, port + ) + assert result == "/storage/prefs?newer=1.09&full=1&limit=1000" + + def test_returns_none_when_already_correct(self, hawk_service, mock_dynamodb_table): + """Returns the current path when the order already matches.""" + import hashlib + import hmac as hmac_mod + + hawk_key = "b" * 64 + hawk_id = base64.urlsafe_b64encode(b"user1:1:9999999999").decode().rstrip("=") + mock_dynamodb_table.get_item.return_value = { + "Item": {"hawk_key": hawk_key, "user_id": "user1", "generation": 1} + } + + ts = "1234567890" + nonce = "def456" + method = "GET" + host = "storage.example.com" + port = 443 + resource = "/storage/clients?full=1&limit=1000" + normalized = f"hawk.1.header\n{ts}\n{nonce}\n{method}\n" f"{resource}\n{host}\n{port}\n\n\n" + client_mac = base64.b64encode( + hmac_mod.new( + hawk_key.encode("ascii"), normalized.encode("ascii"), hashlib.sha256 + ).digest() + ).decode("ascii") + + auth_header = f'Hawk id="{hawk_id}", ts="{ts}", nonce="{nonce}", mac="{client_mac}"' + query_params = {"full": "1", "limit": "1000"} + + result = hawk_service._correct_query_order( + auth_header, hawk_id, method, resource, query_params, host, port + ) + # Returns the resource (it matched on the current order) + assert result == resource + + def test_returns_none_on_missing_header_fields(self, hawk_service): + """Returns 
None when header can't be parsed.""" + result = hawk_service._correct_query_order( + "BadHeader", "hid", "GET", "/path?a=1&b=2", {"a": "1", "b": "2"}, "h", 443 + ) + assert result is None + + def test_returns_none_on_cache_miss(self, hawk_service, mock_dynamodb_table): + """Returns None when hawk key not in cache (lets mohawk handle it).""" + mock_dynamodb_table.get_item.return_value = {} # No Item + hawk_id = base64.urlsafe_b64encode(b"user1:1:9999999999").decode().rstrip("=") + auth_header = f'Hawk id="{hawk_id}", ts="1", nonce="n", mac="m"' + result = hawk_service._correct_query_order( + auth_header, hawk_id, "GET", "/p?a=1&b=2", {"a": "1", "b": "2"}, "h", 443 + ) + assert result is None + + def test_returns_none_when_no_permutation_matches(self, hawk_service, mock_dynamodb_table): + """Returns None when MAC doesn't match any permutation (tampered request).""" + hawk_key = "c" * 64 + hawk_id = base64.urlsafe_b64encode(b"user1:1:9999999999").decode().rstrip("=") + mock_dynamodb_table.get_item.return_value = { + "Item": {"hawk_key": hawk_key, "user_id": "user1", "generation": 1} + } + auth_header = f'Hawk id="{hawk_id}", ts="1", nonce="n", mac="bogusMAC=="' + result = hawk_service._correct_query_order( + auth_header, hawk_id, "GET", "/p?a=1&b=2", {"a": "1", "b": "2"}, "h", 443 + ) + assert result is None diff --git a/lambda/tests/services/test_storage_hawk_middleware.py b/lambda/tests/services/test_storage_hawk_middleware.py index a59e68c4..0d4ee339 100644 --- a/lambda/tests/services/test_storage_hawk_middleware.py +++ b/lambda/tests/services/test_storage_hawk_middleware.py @@ -67,6 +67,7 @@ def test_success_injects_hawk_uid(self): "/1.5/123/storage/bookmarks", "storage.example.com", 443, + query_params=None, ) # Verify hawk_uid was injected