From 7214900ebe1062b524641aba3c8f54fed1c569ea Mon Sep 17 00:00:00 2001 From: vizsatiz Date: Tue, 19 May 2026 09:51:59 +0530 Subject: [PATCH 1/2] fix for floconsole and kms --- .../floconsole/services/token_service.py | 30 +++- .../auth_module/services/token_service.py | 29 +++- .../packages/flo_cloud/flo_cloud/gcp/kms.py | 31 +++- .../packages/flo_cloud/flo_cloud/kms.py | 6 + .../server/scripts/test_kms_auth_flow.py | 140 ++++++++++++++++++ 5 files changed, 226 insertions(+), 10 deletions(-) create mode 100644 wavefront/server/scripts/test_kms_auth_flow.py diff --git a/wavefront/server/apps/floconsole/floconsole/services/token_service.py b/wavefront/server/apps/floconsole/floconsole/services/token_service.py index b7f28c6a..61d0205e 100644 --- a/wavefront/server/apps/floconsole/floconsole/services/token_service.py +++ b/wavefront/server/apps/floconsole/floconsole/services/token_service.py @@ -39,7 +39,7 @@ def __init__( self.is_dev = app_env == 'dev' or (kms_service is None) self.private_key = self._load_key(private_key) self.public_key = self._load_key(public_key) - self.algorithm = TokenAlgorithms.RS256.value if self.is_dev else algorithm.value + self.algorithm = self._resolve_algorithm(kms_service, algorithm, self.is_dev) self.token_expiry = int(token_expiry) self.temporary_token_expiry = int(temporary_token_expiry) self.kms_service = kms_service @@ -51,6 +51,30 @@ def _load_key(self, key: str): key = base64.b64decode(key).decode('ascii') return key + @staticmethod + def _resolve_algorithm( + kms_service: FloKMS | None, + configured: TokenAlgorithms, + is_dev: bool, + ) -> str: + if is_dev: + return TokenAlgorithms.RS256.value + if kms_service is not None: + getter = getattr(kms_service, 'jwt_algorithm', None) + if callable(getter): + return getter() + return configured.value + + def _jwt_decode_algorithms(self) -> list[str]: + """Allow legacy PS256 headers on RS256 (PKCS1) KMS signatures.""" + algorithms = [self.algorithm] + if ( + self.algorithm == TokenAlgorithms.RS256.value + and TokenAlgorithms.PS256.value not in algorithms + ): + algorithms.append(TokenAlgorithms.PS256.value) + return algorithms + def create_token( self, sub: str | None = None, @@ -137,14 +161,14 @@ def decode_token(self, token: str) -> dict: is_valid = self.kms_service.verify(message=digest, signature=signature) if not is_valid: - return {} + raise ValueError('Invalid token signature') public_key_pem = self.kms_service.get_public_key_pem() decoded = jwt.decode( clean_token, public_key_pem, - algorithms=[self.algorithm], + algorithms=self._jwt_decode_algorithms(), issuer=self.issuer, audience=self.audience, ) diff --git a/wavefront/server/modules/auth_module/auth_module/services/token_service.py b/wavefront/server/modules/auth_module/auth_module/services/token_service.py index 33e94f9d..3ab45719 100644 --- a/wavefront/server/modules/auth_module/auth_module/services/token_service.py +++ b/wavefront/server/modules/auth_module/auth_module/services/token_service.py @@ -38,7 +38,7 @@ def __init__( self.is_dev = app_env == 'dev' or (kms_service is None) self.private_key = self._load_key(private_key) self.public_key = self._load_key(public_key) - self.algorithm = TokenAlgorithms.RS256.value if self.is_dev else algorithm.value + self.algorithm = self._resolve_algorithm(kms_service, algorithm, self.is_dev) self.token_expiry = int(token_expiry) self.temporary_token_expiry = int(temporary_token_expiry) self.kms_service = kms_service @@ -49,6 +49,29 @@ def _load_key(self, key: str): key = base64.b64decode(key).decode('ascii') return key + @staticmethod + def _resolve_algorithm( + kms_service: FloKMS, + configured: TokenAlgorithms, + is_dev: bool, + ) -> str: + if is_dev: + return TokenAlgorithms.RS256.value + if kms_service is not None: + getter = getattr(kms_service, 'jwt_algorithm', None) + if callable(getter): + return getter() + return configured.value + + def _jwt_decode_algorithms(self) -> list[str]: + algorithms = [self.algorithm] + if ( + self.algorithm == TokenAlgorithms.RS256.value + and TokenAlgorithms.PS256.value not in algorithms + ): + algorithms.append(TokenAlgorithms.PS256.value) + return algorithms + def create_token( self, sub: str | None = None, @@ -118,14 +141,14 @@ def decode_token(self, token: str) -> dict: is_valid = self.kms_service.verify(message=digest, signature=signature) if not is_valid: - return {} + raise ValueError('Invalid token signature') public_key_pem = self.kms_service.get_public_key_pem() decoded = jwt.decode( token, public_key_pem, - algorithms=[self.algorithm], + algorithms=self._jwt_decode_algorithms(), issuer=self.issuer, audience=self.audience, ) diff --git a/wavefront/server/packages/flo_cloud/flo_cloud/gcp/kms.py b/wavefront/server/packages/flo_cloud/flo_cloud/gcp/kms.py index c51dc731..7f1f6e85 100644 --- a/wavefront/server/packages/flo_cloud/flo_cloud/gcp/kms.py +++ b/wavefront/server/packages/flo_cloud/flo_cloud/gcp/kms.py @@ -16,6 +16,15 @@ gcp_crypto_key = os.getenv('GCP_KMS_CRYPTO_KEY') gcp_crypto_key_version = os.getenv('GCP_KMS_CRYPTO_KEY_VERSION') +# GCP KMS PKCS#1 v1.5 signing algorithms (JWT alg RS256) +_PKCS1_ALGORITHMS = frozenset( + { + kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_2048_SHA256, + kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_3072_SHA256, + kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_4096_SHA256, + } +) + class GcpKMS(FloKMS): def __init__(self): @@ -40,6 +49,11 @@ def __init__(self): crypto_key=gcp_crypto_key, crypto_key_version=gcp_crypto_key_version, ) + public_key = self.kms_client.get_public_key( + request=kms_v1.GetPublicKeyRequest(name=self.key_name) + ) + self._key_algorithm = public_key.algorithm + self._uses_pkcs1 = self._key_algorithm in _PKCS1_ALGORITHMS def encrypt(self, plaintext: str) -> bytes: request = kms_v1.EncryptRequest( @@ -68,20 +82,29 @@ def sign(self, message: bytes, **kwargs) -> bytes: response = self.kms_client.asymmetric_sign(request=request) return response.signature + def jwt_algorithm(self) -> str: + """JWT alg header matching this KMS key (RS256 for PKCS1 keys, PS256 for PSS).""" + return 'RS256' if self._uses_pkcs1 else 'PS256' + def verify(self, message: bytes, signature: bytes, **kwargs) -> bool: public_key_pem: bytes | str = self.get_public_key_pem(encode=True) if isinstance(public_key_pem, str): raise ValueError('Public key is not a bytes object') rsa_key = serialization.load_pem_public_key(public_key_pem, default_backend()) + if self._uses_pkcs1: + verify_padding = padding.PKCS1v15() + else: + verify_padding = padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH, + ) + try: rsa_key.verify( # type: ignore signature=signature, data=message, - padding=padding.PSS( # type: ignore - mgf=padding.MGF1(hashes.SHA256()), - salt_length=padding.PSS.MAX_LENGTH, - ), + padding=verify_padding, algorithm=utils.Prehashed(hashes.SHA256()), # type: ignore ) return True diff --git a/wavefront/server/packages/flo_cloud/flo_cloud/kms.py b/wavefront/server/packages/flo_cloud/flo_cloud/kms.py index c77eae9a..966559e1 100644 --- a/wavefront/server/packages/flo_cloud/flo_cloud/kms.py +++ b/wavefront/server/packages/flo_cloud/flo_cloud/kms.py @@ -37,3 +37,9 @@ def verify(self, message: bytes, signature: bytes, **kwargs) -> bool: def get_public_key_pem(self, **kwargs) -> bytes | str: return self.kms_client.get_public_key_pem(**kwargs) + + def jwt_algorithm(self) -> str: + getter = getattr(self.kms_client, 'jwt_algorithm', None) + if callable(getter): + return getter() + return 'PS256' diff --git a/wavefront/server/scripts/test_kms_auth_flow.py b/wavefront/server/scripts/test_kms_auth_flow.py new file mode 100644 index 00000000..8f51e102 --- /dev/null +++ b/wavefront/server/scripts/test_kms_auth_flow.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +""" +Simulate /floconsole/v1/authenticate token create + require_auth decode (KMS). + +Usage: + export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json + export GCP_PROJECT_ID=... GCP_LOCATION=... GCP_KMS_KEY_RING=... + export GCP_KMS_CRYPTO_KEY=... GCP_KMS_CRYPTO_KEY_VERSION=... + uv run python scripts/test_kms_auth_flow.py +""" + +from __future__ import annotations + +import base64 +import os +import subprocess +import sys +import tempfile +from uuid import uuid4 + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../packages/flo_cloud')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../apps/floconsole')) + +from flo_cloud.gcp.kms import GcpKMS +from flo_cloud.kms import FloKmsService +from floconsole.constants.auth import AUTH_ROLE_ID +from floconsole.services.token_service import TokenAlgorithms, TokenService + +ISSUER = os.getenv('CONSOLE_JWT_ISSUER', 'https://floconsole.rootflo.ai') +AUDIENCE = os.getenv('CONSOLE_JWT_AUDIENCE', 'https://floconsole.rootflo.ai') +PREFIX = os.getenv('CONSOLE_TOKEN_PREFIX', 'fc_') + + +def _dummy_pem_keys() -> tuple[str, str]: + with tempfile.NamedTemporaryFile(suffix='.pem', delete=False) as priv: + subprocess.run( + ['openssl', 'genrsa', '-out', priv.name, '2048'], + check=True, + capture_output=True, + ) + priv_pem = open(priv.name, 'rb').read() + pub_proc = subprocess.run( + ['openssl', 'rsa', '-pubout'], + input=priv_pem, + capture_output=True, + check=True, + ) + return base64.b64encode(priv_pem).decode(), base64.b64encode( + pub_proc.stdout + ).decode() + + +def _simulate_require_auth(decoded: dict) -> str | None: + """Mirror floconsole require_auth checks after decode_token.""" + if 'session_id' not in decoded: + return 'Invalid token: missing session_id' + if 'role_id' not in decoded or decoded['role_id'] != AUTH_ROLE_ID: + return 'Invalid token: Not the console user' + return None + + +def main() -> int: + print('=== KMS auth flow test (create_token + decode_token) ===\n') + + for var in ( + 'GCP_PROJECT_ID', + 'GCP_LOCATION', + 'GCP_KMS_KEY_RING', + 'GCP_KMS_CRYPTO_KEY', + 'GCP_KMS_CRYPTO_KEY_VERSION', + 'GOOGLE_APPLICATION_CREDENTIALS', + ): + print(f' {var}={os.environ.get(var, "")}') + + print('\n--- Step 1: Init KMS (same as ApplicationContainer) ---') + kms = FloKmsService(cloud_provider='gcp') + gcp: GcpKMS = kms.kms_client # type: ignore[assignment] + print(f' KMS key: {gcp.key_name}') + print(f' jwt_algorithm(): {kms.jwt_algorithm()}') + print(f' uses_pkcs1: {gcp._uses_pkcs1}') + + priv, pub = _dummy_pem_keys() + token_service = TokenService( + private_key=priv, + public_key=pub, + kms_service=kms, + algorithm=TokenAlgorithms.PS256, + app_env='production', + token_prefix=PREFIX, + issuer=ISSUER, + audience=AUDIENCE, + ) + print('\n--- Step 2: TokenService (production / KMS) ---') + print(f' is_dev={token_service.is_dev}') + print(f' algorithm={token_service.algorithm}') + + session_id = str(uuid4()) + user_id = str(uuid4()) + print('\n--- Step 3: create_token (POST /authenticate) ---') + token = token_service.create_token( + sub='admin@rootflo.ai', + user_id=user_id, + role_id=AUTH_ROLE_ID, + payload={'session_id': session_id}, + ) + print(f' token length={len(token)}') + print(f' prefix ok={token.startswith(PREFIX)}') + header_alg = __import__('json').loads( + base64.urlsafe_b64decode(token[len(PREFIX) :].split('.')[0] + '==') + )['alg'] + print(f' JWT header alg={header_alg}') + + print('\n--- Step 4: decode_token (require_auth middleware) ---') + try: + decoded = token_service.decode_token(token) + except ValueError as e: + print(f' FAIL ValueError: {e}') + return 1 + except Exception as e: + print(f' FAIL {type(e).__name__}: {e}') + return 1 + + print(f' decoded session_id={decoded.get("session_id")}') + print(f' decoded role_id={decoded.get("role_id")}') + print(f' decoded iss={decoded.get("iss")}') + + err = _simulate_require_auth(decoded) + if err: + print('\n--- Step 5: require_auth ---') + print(f' FAIL: {err}') + return 1 + + print('\n--- Step 5: require_auth ---') + print(' OK: token would be accepted') + print('\n=== PASS: full KMS create + validate flow ===') + return 0 + + +if __name__ == '__main__': + sys.exit(main()) From 59cb84354ba4922b9d9038f0f15eaaa88a19526d Mon Sep 17 00:00:00 2001 From: vizsatiz Date: Thu, 21 May 2026 13:00:26 +0530 Subject: [PATCH 2/2] fix migration issue --- .../modules/db_repo_module/db_repo_module/alembic/env.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/wavefront/server/modules/db_repo_module/db_repo_module/alembic/env.py b/wavefront/server/modules/db_repo_module/db_repo_module/alembic/env.py index 676e98e1..f77fbce2 100644 --- a/wavefront/server/modules/db_repo_module/db_repo_module/alembic/env.py +++ b/wavefront/server/modules/db_repo_module/db_repo_module/alembic/env.py @@ -139,7 +139,11 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) + context.configure( + connection=connection, + target_metadata=target_metadata, + transaction_per_migration=True, + ) with context.begin_transaction(): context.run_migrations()