Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
self.is_dev = app_env == 'dev' or (kms_service is None)
self.private_key = self._load_key(private_key)
self.public_key = self._load_key(public_key)
self.algorithm = TokenAlgorithms.RS256.value if self.is_dev else algorithm.value
self.algorithm = self._resolve_algorithm(kms_service, algorithm, self.is_dev)
self.token_expiry = int(token_expiry)
self.temporary_token_expiry = int(temporary_token_expiry)
self.kms_service = kms_service
Expand All @@ -51,6 +51,30 @@ def _load_key(self, key: str):
key = base64.b64decode(key).decode('ascii')
return key

@staticmethod
def _resolve_algorithm(
kms_service: FloKMS | None,
configured: TokenAlgorithms,
is_dev: bool,
) -> str:
if is_dev:
return TokenAlgorithms.RS256.value
if kms_service is not None:
getter = getattr(kms_service, 'jwt_algorithm', None)
if callable(getter):
return getter()
return configured.value

def _jwt_decode_algorithms(self) -> list[str]:
"""Allow legacy PS256 headers on RS256 (PKCS1) KMS signatures."""
algorithms = [self.algorithm]
if (
self.algorithm == TokenAlgorithms.RS256.value
and TokenAlgorithms.PS256.value not in algorithms
):
algorithms.append(TokenAlgorithms.PS256.value)
return algorithms

def create_token(
self,
sub: str | None = None,
Expand Down Expand Up @@ -137,14 +161,14 @@ def decode_token(self, token: str) -> dict:

is_valid = self.kms_service.verify(message=digest, signature=signature)
if not is_valid:
return {}
raise ValueError('Invalid token signature')

public_key_pem = self.kms_service.get_public_key_pem()

decoded = jwt.decode(
clean_token,
public_key_pem,
algorithms=[self.algorithm],
algorithms=self._jwt_decode_algorithms(),
issuer=self.issuer,
audience=self.audience,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
self.is_dev = app_env == 'dev' or (kms_service is None)
self.private_key = self._load_key(private_key)
self.public_key = self._load_key(public_key)
self.algorithm = TokenAlgorithms.RS256.value if self.is_dev else algorithm.value
self.algorithm = self._resolve_algorithm(kms_service, algorithm, self.is_dev)
self.token_expiry = int(token_expiry)
self.temporary_token_expiry = int(temporary_token_expiry)
self.kms_service = kms_service
Expand All @@ -49,6 +49,29 @@ def _load_key(self, key: str):
key = base64.b64decode(key).decode('ascii')
return key

@staticmethod
def _resolve_algorithm(
kms_service: FloKMS,
configured: TokenAlgorithms,
is_dev: bool,
) -> str:
if is_dev:
return TokenAlgorithms.RS256.value
if kms_service is not None:
getter = getattr(kms_service, 'jwt_algorithm', None)
if callable(getter):
return getter()
return configured.value

def _jwt_decode_algorithms(self) -> list[str]:
algorithms = [self.algorithm]
if (
self.algorithm == TokenAlgorithms.RS256.value
and TokenAlgorithms.PS256.value not in algorithms
):
algorithms.append(TokenAlgorithms.PS256.value)
return algorithms

def create_token(
self,
sub: str | None = None,
Expand Down Expand Up @@ -118,14 +141,14 @@ def decode_token(self, token: str) -> dict:

is_valid = self.kms_service.verify(message=digest, signature=signature)
if not is_valid:
return {}
raise ValueError('Invalid token signature')

public_key_pem = self.kms_service.get_public_key_pem()

decoded = jwt.decode(
token,
public_key_pem,
algorithms=[self.algorithm],
algorithms=self._jwt_decode_algorithms(),
issuer=self.issuer,
audience=self.audience,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ def run_migrations_online() -> None:
)

with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
context.configure(
connection=connection,
target_metadata=target_metadata,
transaction_per_migration=True,
)

with context.begin_transaction():
context.run_migrations()
Expand Down
31 changes: 27 additions & 4 deletions wavefront/server/packages/flo_cloud/flo_cloud/gcp/kms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
gcp_crypto_key = os.getenv('GCP_KMS_CRYPTO_KEY')
gcp_crypto_key_version = os.getenv('GCP_KMS_CRYPTO_KEY_VERSION')

# GCP KMS PKCS#1 v1.5 signing algorithms (JWT alg RS256)
_PKCS1_ALGORITHMS = frozenset(
{
kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_2048_SHA256,
kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_3072_SHA256,
kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_4096_SHA256,
}
)


class GcpKMS(FloKMS):
def __init__(self):
Expand All @@ -40,6 +49,11 @@ def __init__(self):
crypto_key=gcp_crypto_key,
crypto_key_version=gcp_crypto_key_version,
)
public_key = self.kms_client.get_public_key(
request=kms_v1.GetPublicKeyRequest(name=self.key_name)
)
self._key_algorithm = public_key.algorithm
self._uses_pkcs1 = self._key_algorithm in _PKCS1_ALGORITHMS

def encrypt(self, plaintext: str) -> bytes:
request = kms_v1.EncryptRequest(
Expand Down Expand Up @@ -68,20 +82,29 @@ def sign(self, message: bytes, **kwargs) -> bytes:
response = self.kms_client.asymmetric_sign(request=request)
return response.signature

def jwt_algorithm(self) -> str:
"""JWT alg header matching this KMS key (RS256 for PKCS1 keys, PS256 for PSS)."""
return 'RS256' if self._uses_pkcs1 else 'PS256'

def verify(self, message: bytes, signature: bytes, **kwargs) -> bool:
public_key_pem: bytes | str = self.get_public_key_pem(encode=True)
if isinstance(public_key_pem, str):
raise ValueError('Public key is not a bytes object')
rsa_key = serialization.load_pem_public_key(public_key_pem, default_backend())

if self._uses_pkcs1:
verify_padding = padding.PKCS1v15()
else:
verify_padding = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
)

try:
rsa_key.verify( # type: ignore
signature=signature,
data=message,
padding=padding.PSS( # type: ignore
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
padding=verify_padding,
algorithm=utils.Prehashed(hashes.SHA256()), # type: ignore
)
return True
Expand Down
6 changes: 6 additions & 0 deletions wavefront/server/packages/flo_cloud/flo_cloud/kms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,9 @@ def verify(self, message: bytes, signature: bytes, **kwargs) -> bool:

def get_public_key_pem(self, **kwargs) -> bytes | str:
return self.kms_client.get_public_key_pem(**kwargs)

def jwt_algorithm(self) -> str:
getter = getattr(self.kms_client, 'jwt_algorithm', None)
if callable(getter):
return getter()
return 'PS256'
140 changes: 140 additions & 0 deletions wavefront/server/scripts/test_kms_auth_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#!/usr/bin/env python3
"""
Simulate /floconsole/v1/authenticate token create + require_auth decode (KMS).

Usage:
export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
export GCP_PROJECT_ID=... GCP_LOCATION=... GCP_KMS_KEY_RING=...
export GCP_KMS_CRYPTO_KEY=... GCP_KMS_CRYPTO_KEY_VERSION=...
uv run python scripts/test_kms_auth_flow.py
"""

from __future__ import annotations

import base64
import os
import subprocess
import sys
import tempfile
from uuid import uuid4

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../packages/flo_cloud'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../apps/floconsole'))

from flo_cloud.gcp.kms import GcpKMS
from flo_cloud.kms import FloKmsService
from floconsole.constants.auth import AUTH_ROLE_ID
from floconsole.services.token_service import TokenAlgorithms, TokenService

ISSUER = os.getenv('CONSOLE_JWT_ISSUER', 'https://floconsole.rootflo.ai')
AUDIENCE = os.getenv('CONSOLE_JWT_AUDIENCE', 'https://floconsole.rootflo.ai')
PREFIX = os.getenv('CONSOLE_TOKEN_PREFIX', 'fc_')


def _dummy_pem_keys() -> tuple[str, str]:
with tempfile.NamedTemporaryFile(suffix='.pem', delete=False) as priv:
subprocess.run(
['openssl', 'genrsa', '-out', priv.name, '2048'],

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/usr/bin/env bash
set -euo pipefail

# Verify how openssl resolves in the current environment.
command -v openssl
python - <<'PY'
import shutil
p = shutil.which("openssl")
print(f"shutil.which('openssl') => {p}")
if not p:
    raise SystemExit("openssl missing")
PY

Repository: rootflo/wavefront

Length of output: 43


🏁 Script executed:

# First, locate and examine the test file
find . -name "test_kms_auth_flow.py" -type f

Repository: rootflo/wavefront

Length of output: 110


🏁 Script executed:

cat -n ./wavefront/server/scripts/test_kms_auth_flow.py

Repository: rootflo/wavefront

Length of output: 5589


Resolve openssl to an absolute path before invoking subprocess.

Both calls to openssl on lines 37 and 43 use a partial executable path that relies on ambient PATH, creating a path shadowing vulnerability in compromised environments.

Use shutil.which() to resolve the absolute path at module load time and fail explicitly if openssl is not found:

Suggested fix
+import shutil
 ...
+OPENSSL_BIN = shutil.which('openssl')
+if not OPENSSL_BIN:
+    raise RuntimeError('openssl not found in PATH')
 ...
         subprocess.run(
-            ['openssl', 'genrsa', '-out', priv.name, '2048'],
+            [OPENSSL_BIN, 'genrsa', '-out', priv.name, '2048'],
             check=True,
             capture_output=True,
         )
  ...
     pub_proc = subprocess.run(
-        ['openssl', 'rsa', '-pubout'],
+        [OPENSSL_BIN, 'rsa', '-pubout'],
         input=priv_pem,
         capture_output=True,
         check=True,
     )
🧰 Tools
🪛 Ruff (0.15.13)

[error] 37-37: Starting a process with a partial executable path

(S607)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@wavefront/server/scripts/test_kms_auth_flow.py` at line 37, The subprocess
invocations that use the literal 'openssl' (e.g., the call with args ['openssl',
'genrsa', '-out', priv.name, '2048'] and the similar call later) should resolve
the openssl executable to an absolute path at module load using
shutil.which('openssl') and fail fast if not found; replace the string 'openssl'
in those argument lists with the resolved path (and raise a clear RuntimeError
or SystemExit when shutil.which returns None) so subprocess calls do not rely on
ambient PATH and prevent path‑shadowing vulnerabilities.

check=True,
capture_output=True,
)
priv_pem = open(priv.name, 'rb').read()
Comment on lines +35 to +41

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Clean up temporary private key material after reading it.

delete=False leaves the generated private key file on disk and it is never removed. Even for a test script, this is an avoidable secret-leak risk.

Suggested fix
 def _dummy_pem_keys() -> tuple[str, str]:
-    with tempfile.NamedTemporaryFile(suffix='.pem', delete=False) as priv:
+    with tempfile.NamedTemporaryFile(suffix='.pem', delete=False) as priv:
+        priv_path = priv.name
         subprocess.run(
-            ['openssl', 'genrsa', '-out', priv.name, '2048'],
+            ['openssl', 'genrsa', '-out', priv_path, '2048'],
             check=True,
             capture_output=True,
         )
-        priv_pem = open(priv.name, 'rb').read()
+    try:
+        with open(priv_path, 'rb') as f:
+            priv_pem = f.read()
+    finally:
+        os.unlink(priv_path)
🧰 Tools
🪛 Ruff (0.15.13)

[error] 36-36: subprocess call: check for execution of untrusted input

(S603)


[error] 37-37: Starting a process with a partial executable path

(S607)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@wavefront/server/scripts/test_kms_auth_flow.py` around lines 35 - 41, The
temporary private key file is created with tempfile.NamedTemporaryFile(...,
delete=False) and never removed, leaving secret material on disk; modify the
test_kms_auth_flow.py flow to either use delete=True or explicitly remove the
file after reading priv_pem (e.g., call os.remove(priv.name) or use a context
where the file is created without delete=False), ensuring cleanup happens even
on errors (use try/finally or contextmanager) and reference the
tempfile.NamedTemporaryFile invocation, the priv variable, and the priv_pem read
to locate where to add the removal.

pub_proc = subprocess.run(
['openssl', 'rsa', '-pubout'],
input=priv_pem,
capture_output=True,
check=True,
)
return base64.b64encode(priv_pem).decode(), base64.b64encode(
pub_proc.stdout
).decode()


def _simulate_require_auth(decoded: dict) -> str | None:
"""Mirror floconsole require_auth checks after decode_token."""
if 'session_id' not in decoded:
return 'Invalid token: missing session_id'
if 'role_id' not in decoded or decoded['role_id'] != AUTH_ROLE_ID:
return 'Invalid token: Not the console user'
return None


def main() -> int:
print('=== KMS auth flow test (create_token + decode_token) ===\n')

for var in (
'GCP_PROJECT_ID',
'GCP_LOCATION',
'GCP_KMS_KEY_RING',
'GCP_KMS_CRYPTO_KEY',
'GCP_KMS_CRYPTO_KEY_VERSION',
'GOOGLE_APPLICATION_CREDENTIALS',
):
print(f' {var}={os.environ.get(var, "<not set>")}')

print('\n--- Step 1: Init KMS (same as ApplicationContainer) ---')
kms = FloKmsService(cloud_provider='gcp')
gcp: GcpKMS = kms.kms_client # type: ignore[assignment]
print(f' KMS key: {gcp.key_name}')
print(f' jwt_algorithm(): {kms.jwt_algorithm()}')
print(f' uses_pkcs1: {gcp._uses_pkcs1}')

priv, pub = _dummy_pem_keys()
token_service = TokenService(
private_key=priv,
public_key=pub,
kms_service=kms,
algorithm=TokenAlgorithms.PS256,
app_env='production',
token_prefix=PREFIX,
issuer=ISSUER,
audience=AUDIENCE,
)
print('\n--- Step 2: TokenService (production / KMS) ---')
print(f' is_dev={token_service.is_dev}')
print(f' algorithm={token_service.algorithm}')

session_id = str(uuid4())
user_id = str(uuid4())
print('\n--- Step 3: create_token (POST /authenticate) ---')
token = token_service.create_token(
sub='admin@rootflo.ai',
user_id=user_id,
role_id=AUTH_ROLE_ID,
payload={'session_id': session_id},
)
print(f' token length={len(token)}')
print(f' prefix ok={token.startswith(PREFIX)}')
header_alg = __import__('json').loads(
base64.urlsafe_b64decode(token[len(PREFIX) :].split('.')[0] + '==')
)['alg']
print(f' JWT header alg={header_alg}')
Comment on lines +107 to +111

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fail fast when token prefix validation fails.

The script currently prints prefix ok=False but still proceeds. That can hide regressions in prefix handling and produce misleading pass/fail outcomes.

Suggested fix
     print(f'  token length={len(token)}')
-    print(f'  prefix ok={token.startswith(PREFIX)}')
+    prefix_ok = token.startswith(PREFIX)
+    print(f'  prefix ok={prefix_ok}')
+    if not prefix_ok:
+        print('  FAIL: token prefix mismatch')
+        return 1
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f' prefix ok={token.startswith(PREFIX)}')
header_alg = __import__('json').loads(
base64.urlsafe_b64decode(token[len(PREFIX) :].split('.')[0] + '==')
)['alg']
print(f' JWT header alg={header_alg}')
prefix_ok = token.startswith(PREFIX)
print(f' prefix ok={prefix_ok}')
if not prefix_ok:
print(' FAIL: token prefix mismatch')
return 1
header_alg = __import__('json').loads(
base64.urlsafe_b64decode(token[len(PREFIX) :].split('.')[0] + '==')
)['alg']
print(f' JWT header alg={header_alg}')
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@wavefront/server/scripts/test_kms_auth_flow.py` around lines 107 - 111, The
code prints the prefix check but continues even when the token doesn't start
with PREFIX; update the block around the print and header_alg extraction to fail
fast: after computing/reading token and PREFIX, check token.startswith(PREFIX)
and if false immediately exit with a non-zero status or raise an exception
(e.g., SystemExit or ValueError) so the script stops instead of decoding the
remainder; reference the variables token, PREFIX and the header_alg extraction
expression to locate where to insert the immediate failure.


print('\n--- Step 4: decode_token (require_auth middleware) ---')
try:
decoded = token_service.decode_token(token)
except ValueError as e:
print(f' FAIL ValueError: {e}')
return 1
except Exception as e:
print(f' FAIL {type(e).__name__}: {e}')
return 1

print(f' decoded session_id={decoded.get("session_id")}')
print(f' decoded role_id={decoded.get("role_id")}')
print(f' decoded iss={decoded.get("iss")}')

err = _simulate_require_auth(decoded)
if err:
print('\n--- Step 5: require_auth ---')
print(f' FAIL: {err}')
return 1

print('\n--- Step 5: require_auth ---')
print(' OK: token would be accepted')
print('\n=== PASS: full KMS create + validate flow ===')
return 0


if __name__ == '__main__':
sys.exit(main())
Loading