From b6ec8624867e65083e546deec9210cfc2a01f020 Mon Sep 17 00:00:00 2001 From: Izzat Nisfer Date: Thu, 7 May 2026 15:44:57 +0530 Subject: [PATCH 1/6] feat: implement real-time alert system with Socket.IO --- requirements.txt | 4 + src/d2/api/app.py | 227 ++++++++++++++++++++++++++++++- src/d2/api/routes/alerts.py | 53 +++++++- src/d2/config.py | 6 + src/d2/db/repository.py | 18 ++- src/d2/observability/metrics.py | 17 +++ src/d2/streaming/worker.py | 6 + tests/unit/test_alert_payload.py | 23 ++++ 8 files changed, 338 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test_alert_payload.py diff --git a/requirements.txt b/requirements.txt index 7703e8f..a999dd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,10 @@ vitaldb>=1.4 # Streaming (member: Kafka track) kafka-python>=2.0 +aiokafka>=0.10 + +# Real-time alerts (Socket.IO) +python-socketio>=5.11 # Time-series (InfluxDB 2.x — worker vitals writes) influxdb-client>=1.38,<2 diff --git a/src/d2/api/app.py b/src/d2/api/app.py index 6eae59d..5c9001b 100644 --- a/src/d2/api/app.py +++ b/src/d2/api/app.py @@ -8,22 +8,33 @@ from __future__ import annotations -import logging -import socket import time +import ssl from contextlib import asynccontextmanager +from typing import Any -from fastapi import FastAPI +import socketio +from aiokafka import AIOKafkaConsumer +from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.requests import Request from fastapi.responses import JSONResponse, Response from prometheus_client import CONTENT_TYPE_LATEST, generate_latest from sqlalchemy.exc import OperationalError +from starlette.middleware.sessions import SessionMiddleware -from d2.api.routes import health, vitals, vitals_routes, patients, alerts, devices, reports -from d2.config import database_hostname, get_settings -from d2.db.connection import dispose_engine -from d2.observability.metrics import API_ERRORS_TOTAL, API_REQUESTS_TOTAL, API_RESPONSE_SECONDS +from d2.api.deps import get_current_user +from d2.api.id_utils import format_patient_slug +from d2.api.routes import auth, alerts, devices, health, patients, reports, vitals, vitals_routes +from d2.api.schemas.alert import AlertPayload +from d2.api.session import InMemorySessionStore +from d2.config import database_hostname, get_settings, resolve_repo_relative_path +from d2.db import repository +from d2.db.connection import dispose_engine, get_session_maker +from d2.observability.metrics import ( + ALERTS_DELIVERED, ALERTS_DELIVERY_FAILED, + API_ERRORS_TOTAL, API_REQUESTS_TOTAL, API_RESPONSE_SECONDS, +) log = logging.getLogger(__name__) @@ -35,6 +46,174 @@ } +def _split_origins(raw: str | None) -> list[str]: + if not raw: + return [] + return [origin.strip() for origin in raw.split(",") if origin.strip()] + + +def _extract_auth_token(auth: Any) -> str | None: + if not isinstance(auth, dict): + return None + token = auth.get("token") + if not token: + return None + token_str = str(token).strip() + if token_str.lower().startswith("bearer "): + return token_str[7:].strip() + return token_str + + +def _severity_from_alert(alert: dict[str, Any]) -> str: + if alert.get("is_falling"): + return "critical" + for key in ("hr_status", "spo2_status", "temp_status"): + if str(alert.get(key) or "").lower() == "critical": + return "critical" + for key in ("hr_status", "spo2_status", "temp_status"): + if str(alert.get(key) or "").lower() == "warning": + return "warning" + return "info" + + +def _message_from_alert(alert: dict[str, Any]) -> str: + reason = str(alert.get("reason") or "").strip() + return reason if reason else "threshold_breach" + + +async def _resolve_patient_context( + session: Any, + alert: dict[str, Any], +) -> tuple[int | None, str]: + patient_id = alert.get("patient_id") + if not patient_id and alert.get("device_id"): + patient_id = await repository.fetch_assigned_patient_for_device( + session, str(alert["device_id"]) + ) + + patient_name = "Unknown patient" + if patient_id: + row = await repository.fetch_patient_by_id(session, int(patient_id)) + if row and row.get("full_name"): + patient_name = str(row["full_name"]).strip() + + return (int(patient_id) if patient_id else None, patient_name) + + +def _build_alert_payload( + alert: dict[str, Any], + patient_id: int | None, + patient_name: str, +) -> AlertPayload: + patient_slug = format_patient_slug(patient_id or 0) + return AlertPayload( + patientId=patient_slug, + patientName=patient_name, + message=_message_from_alert(alert), + severity=_severity_from_alert(alert), + ) + + +def _build_kafka_ssl_context(settings: Any) -> ssl.SSLContext | None: + if settings.kafka_security_protocol.upper() == "PLAINTEXT": + return None + context = ssl.create_default_context() + ca = resolve_repo_relative_path(settings.kafka_ssl_cafile) + cert = resolve_repo_relative_path(settings.kafka_ssl_certfile) + key = resolve_repo_relative_path(settings.kafka_ssl_keyfile) + if ca: + context.load_verify_locations(cafile=ca) + if cert and key: + context.load_cert_chain(certfile=cert, keyfile=key) + return context + + +def _register_socket_handlers( + socket_server: socketio.AsyncServer, + *, + require_auth: bool, +) -> None: + @socket_server.event + async def connect(sid, environ, auth): + token = _extract_auth_token(auth) + if require_auth: + if not token: + raise socketio.exceptions.ConnectionRefusedError("auth required") + if not InMemorySessionStore.get_session(token): + raise socketio.exceptions.ConnectionRefusedError("invalid session") + log.info("[Socket] client connected sid=%s", sid) + + @socket_server.event + async def disconnect(sid): + log.info("[Socket] client disconnected sid=%s", sid) + + @socket_server.on("join_room") + async def join_room(sid, payload): + room_id = None + if isinstance(payload, dict): + room_id = payload.get("roomId") or payload.get("room_id") + if room_id: + socket_server.enter_room(sid, str(room_id)) + log.info("[Socket] %s joined room=%s", sid, room_id) + + @socket_server.on("leave_room") + async def leave_room(sid, payload): + room_id = None + if isinstance(payload, dict): + room_id = payload.get("roomId") or payload.get("room_id") + if room_id: + socket_server.leave_room(sid, str(room_id)) + log.info("[Socket] %s left room=%s", sid, room_id) + + +async def _consume_alerts(socket_server: socketio.AsyncServer, settings: Any) -> None: + ssl_context = _build_kafka_ssl_context(settings) + session_maker = get_session_maker(settings) + retry_delay = 5 + + while True: + consumer = AIOKafkaConsumer( + settings.kafka_topic_alerts, + bootstrap_servers=settings.kafka_bootstrap_servers.split(","), + group_id="d2-socketio-broadcaster", + value_deserializer=lambda m: json.loads(m.decode("utf-8")), + enable_auto_commit=True, + security_protocol=settings.kafka_security_protocol, + ssl_context=ssl_context, + ) + try: + await consumer.start() + log.info("Socket.IO alert consumer started topic=%s", settings.kafka_topic_alerts) + async for msg in consumer: + alert = msg.value + async with session_maker() as session: + patient_id, patient_name = await _resolve_patient_context(session, alert) + payload = _build_alert_payload(alert, patient_id, patient_name) + severity = payload.severity + + try: + await socket_server.emit("alert", payload.model_dump()) + ALERTS_DELIVERED.labels(severity=severity).inc() + log.info( + "[Socket] alert broadcast severity=%s patient_id=%s", + severity, + patient_id, + ) + except Exception: + ALERTS_DELIVERY_FAILED.labels(severity=severity).inc() + log.exception("[Socket] alert broadcast failed") + except asyncio.CancelledError: + break + except Exception: + log.exception("Socket.IO alert consumer failed; retrying in %ss", retry_delay) + await asyncio.sleep(retry_delay) + finally: + try: + await consumer.stop() + except Exception: + log.exception("Socket.IO alert consumer stop failed") + + @asynccontextmanager async def _lifespan(app: FastAPI): """Log which DB the process uses; release engine on shutdown.""" @@ -44,11 +223,26 @@ async def _lifespan(app: FastAPI): database_hostname(settings.database_url), settings.pg_ssl_required, ) + + socket_server: socketio.AsyncServer | None = getattr(app.state, "socketio_server", None) + kafka_task: asyncio.Task | None = None + if socket_server is not None: + kafka_task = asyncio.create_task(_consume_alerts(socket_server, settings)) + yield + + if kafka_task is not None: + kafka_task.cancel() + try: + await kafka_task + except asyncio.CancelledError: + pass await dispose_engine() def create_app() -> FastAPI: + settings = get_settings() + app = FastAPI( title="D2 Intelligence API", description="Data & Intelligence layer — vitals scoring, history, alerts integration.", @@ -56,6 +250,22 @@ def create_app() -> FastAPI: lifespan=_lifespan, ) + socketio_origins = _split_origins(settings.socketio_allow_origins or settings.cors_allow_origins) + socket_server = socketio.AsyncServer( + async_mode="asgi", + cors_allowed_origins=socketio_origins or ["*"], + ) + _register_socket_handlers(socket_server, require_auth=settings.socketio_require_auth) + app.state.socketio_server = socket_server + + app.add_middleware( + SessionMiddleware, + secret_key=settings.session_secret_key, + https_only=False, + same_site="lax", + max_age=600, + ) + # CORS — allow all origins for dev / demo; restrict to D3 origin in production. app.add_middleware( CORSMiddleware, @@ -88,6 +298,9 @@ async def _metrics_middleware(request: Request, call_next): status=str(status_code), ).inc() + # Mount Socket.IO server on /socket.io (matches frontend default path). + app.mount("/socket.io", socketio.ASGIApp(socket_server)) + @app.get("/metrics", include_in_schema=False) def metrics() -> Response: """Prometheus scrape endpoint.""" diff --git a/src/d2/api/routes/alerts.py b/src/d2/api/routes/alerts.py index 0b6ab2b..703931e 100644 --- a/src/d2/api/routes/alerts.py +++ b/src/d2/api/routes/alerts.py @@ -2,26 +2,59 @@ from __future__ import annotations -from fastapi import APIRouter, HTTPException +import logging +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import text from d2.api.converters import row_to_live_alert -from d2.api.deps import DbSession +from d2.api.deps import DbSession, require_roles from d2.api.schemas import AcknowledgeAlertResponse, LiveAlert from d2.db import repository +from d2.observability.metrics import ALERTS_ACKNOWLEDGED + +log = logging.getLogger(__name__) router = APIRouter() -@router.get("/alerts", response_model=list[LiveAlert], tags=["alerts"]) -async def get_alerts(session: DbSession, limit: int = 500) -> list[LiveAlert]: +@router.get( + "/alerts", + response_model=list[LiveAlert], + tags=["alerts"], + dependencies=[Depends(require_roles("doctor", "admin", "patient"))], +) +async def get_alerts( + session: DbSession, + user: dict = Depends(require_roles("doctor", "admin", "patient")), + patient_id: int | None = None, + limit: int = 500 +) -> list[LiveAlert]: """All alerts, newest first (see ``alerts`` + ``patients`` join).""" - rows = await repository.list_live_alerts(session, limit=limit) + role = str(user.get("role") or "").lower() + + # If the user is a patient, they can only see their own alerts + if "patient" in role: + auth_subject_id = user.get("auth_subject_id") + if not auth_subject_id: + raise HTTPException(status_code=403, detail="Patient profile not linked") + + pid = await repository.fetch_patient_id_by_auth_subject(session, auth_subject_id) + if not pid: + raise HTTPException(status_code=403, detail="Patient profile not found") + + patient_id = pid + + rows = await repository.list_live_alerts(session, limit=limit, patient_id=patient_id) return [row_to_live_alert(dict(r)) for r in rows] -@router.patch("/alerts/{alert_id}/acknowledge", response_model=AcknowledgeAlertResponse, tags=["alerts"]) -async def acknowledge_alert(alert_id: str, session: DbSession) -> AcknowledgeAlertResponse: +@router.patch( + "/alerts/{alert_id}/acknowledge", + response_model=AcknowledgeAlertResponse, + tags=["alerts"], + dependencies=[Depends(require_roles("doctor", "admin"))], +) +async def acknowledge_alert(alert_id: str, session: DbSession, user: dict = Depends(require_roles("doctor", "admin"))) -> AcknowledgeAlertResponse: try: aid = int(alert_id) except ValueError as exc: @@ -34,6 +67,12 @@ async def acknowledge_alert(alert_id: str, session: DbSession) -> AcknowledgeAle raise HTTPException(status_code=404, detail=f"Alert with id '{alert_id}' not found") await repository.acknowledge_alert(session, aid) + ALERTS_ACKNOWLEDGED.inc() + + # Log acknowledgment + username = user.get("preferred_username") or user.get("email") or "unknown_user" + log.info(f"Alert {alert_id} acknowledged by user: {username} (auth_subject_id: {user.get('auth_subject_id')})") + return AcknowledgeAlertResponse( success=True, message=f"Alert {alert_id} acknowledged successfully" ) diff --git a/src/d2/config.py b/src/d2/config.py index 4fcd05d..8a2d57b 100644 --- a/src/d2/config.py +++ b/src/d2/config.py @@ -90,6 +90,12 @@ def influx_configured(self) -> bool: # Optional Keycloak checks (D4 will supply values) keycloak_well_known_url: str | None = None + # Auth — OIDC/PKCE session management (session_id stored in secure cookie) + session_secret_key: str = "dev-key-not-secure" # MUST override in production via SESSION_SECRET_KEY env var + cors_allow_origins: str = "http://localhost:3000" # Override via CORS_ALLOW_ORIGINS (comma-separated) + socketio_require_auth: bool = True + socketio_allow_origins: str | None = None # Override via SOCKETIO_ALLOW_ORIGINS env var + @model_validator(mode="after") def ensure_asyncpg_driver_in_database_url(self) -> Self: """SQLAlchemy AsyncEngine requires asyncpg. Supabase / dashboards often paste diff --git a/src/d2/db/repository.py b/src/d2/db/repository.py index 2b92036..f7eccf4 100644 --- a/src/d2/db/repository.py +++ b/src/d2/db/repository.py @@ -322,6 +322,7 @@ async def list_alerts( FROM alerts a LEFT JOIN patients p ON p.id = a.patient_id LEFT JOIN users u ON u.id = p.user_id + WHERE (CAST(:patient_id AS INTEGER) IS NULL OR a.patient_id = CAST(:patient_id AS INTEGER)) ORDER BY a.ts DESC LIMIT :limit """ @@ -505,11 +506,24 @@ async def fetch_vitals_for_patient_range( return [dict(r) for r in result.mappings().all()] -async def list_live_alerts(session: AsyncSession, limit: int = 500) -> list[dict[str, Any]]: - result = await session.execute(_LIST_ALERTS_LIVE, {"limit": limit}) +async def list_live_alerts(session: AsyncSession, limit: int = 500, patient_id: int | None = None) -> list[dict[str, Any]]: + result = await session.execute(_LIST_ALERTS_LIVE, {"limit": limit, "patient_id": patient_id}) return [dict(r) for r in result.mappings().all()] +_FETCH_PATIENT_ID_BY_AUTH_SUBJECT = text( + """ + SELECT id + FROM patients + WHERE auth_subject_id = :auth_subject_id + """ +) + +async def fetch_patient_id_by_auth_subject(session: AsyncSession, auth_subject_id: int) -> int | None: + result = await session.execute(_FETCH_PATIENT_ID_BY_AUTH_SUBJECT, {"auth_subject_id": auth_subject_id}) + row = result.scalar_one_or_none() + return int(row) if row else None + async def acknowledge_alert(session: AsyncSession, alert_id: int) -> bool: result = await session.execute(_ACK_ALERT, {"id": alert_id}) row = result.first() diff --git a/src/d2/observability/metrics.py b/src/d2/observability/metrics.py index 1021e08..d6d15f0 100644 --- a/src/d2/observability/metrics.py +++ b/src/d2/observability/metrics.py @@ -74,6 +74,23 @@ labelnames=("status",), ) +ALERTS_DELIVERED = Counter( + "d2_alerts_delivered_total", + "Alerts broadcast to Socket.IO clients.", + labelnames=("severity",), +) + +ALERTS_DELIVERY_FAILED = Counter( + "d2_alerts_delivery_failed_total", + "Alert broadcast failures (Socket.IO).", + labelnames=("severity",), +) + +ALERTS_ACKNOWLEDGED = Counter( + "d2_alerts_acknowledged_total", + "Alerts acknowledged by a clinician.", +) + PROCESSING_SECONDS = Histogram( "d2_processing_seconds", "End-to-end time to handle one D1 message (parse -> features -> score -> persist -> publish).", diff --git a/src/d2/streaming/worker.py b/src/d2/streaming/worker.py index 5335dcd..0392f18 100644 --- a/src/d2/streaming/worker.py +++ b/src/d2/streaming/worker.py @@ -497,6 +497,12 @@ def _process_one( if db_row.get("overall_status") == "critical" or db_row.get("is_falling"): alert = build_alert(db_row) loop.run_until_complete(_persist_alert(session_maker, alert)) + log.info( + "alert created device_id=%s patient_id=%s reason=%s", + alert.get("device_id"), + alert.get("patient_id"), + alert.get("reason"), + ) ALERT_DELIVERY_ATTEMPTS.labels(topic=settings.kafka_topic_alerts).inc() try: future = producer.send( diff --git a/tests/unit/test_alert_payload.py b/tests/unit/test_alert_payload.py new file mode 100644 index 0000000..1c7040b --- /dev/null +++ b/tests/unit/test_alert_payload.py @@ -0,0 +1,23 @@ +from d2.api.app import _build_alert_payload, _message_from_alert, _severity_from_alert + + +def test_message_from_alert_defaults(): + assert _message_from_alert({}) == "threshold_breach" + + +def test_severity_critical_on_fall(): + alert = {"is_falling": True} + assert _severity_from_alert(alert) == "critical" + + +def test_severity_warning_on_status(): + alert = {"hr_status": "warning"} + assert _severity_from_alert(alert) == "warning" + + +def test_build_alert_payload_defaults_patient(): + alert = {"reason": "hr=130 (critical)", "hr_status": "critical"} + payload = _build_alert_payload(alert, None, "Unknown patient") + assert payload.patientId == "P000" + assert payload.severity == "critical" + assert payload.message == "hr=130 (critical)" From ba1b684d2c1358285c82fc85fcd615bdf49d3a3d Mon Sep 17 00:00:00 2001 From: Izzat Nisfer Date: Thu, 7 May 2026 16:01:46 +0530 Subject: [PATCH 2/6] feat: async handling and add unit tests for auth token extraction and origin splitting --- src/d2/api/app.py | 4 +- src/d2/api/routes/alerts.py | 8 ++-- tests/integration/test_socket_bridge.py | 60 +++++++++++++++++++++++ tests/unit/test_alerts_route.py | 64 +++++++++++++++++++++++++ tests/unit/test_app_logic.py | 25 ++++++++++ 5 files changed, 155 insertions(+), 6 deletions(-) create mode 100644 tests/integration/test_socket_bridge.py create mode 100644 tests/unit/test_alerts_route.py create mode 100644 tests/unit/test_app_logic.py diff --git a/src/d2/api/app.py b/src/d2/api/app.py index 5c9001b..0e33620 100644 --- a/src/d2/api/app.py +++ b/src/d2/api/app.py @@ -153,7 +153,7 @@ async def join_room(sid, payload): if isinstance(payload, dict): room_id = payload.get("roomId") or payload.get("room_id") if room_id: - socket_server.enter_room(sid, str(room_id)) + await socket_server.enter_room(sid, str(room_id)) log.info("[Socket] %s joined room=%s", sid, room_id) @socket_server.on("leave_room") @@ -162,7 +162,7 @@ async def leave_room(sid, payload): if isinstance(payload, dict): room_id = payload.get("roomId") or payload.get("room_id") if room_id: - socket_server.leave_room(sid, str(room_id)) + await socket_server.leave_room(sid, str(room_id)) log.info("[Socket] %s left room=%s", sid, room_id) diff --git a/src/d2/api/routes/alerts.py b/src/d2/api/routes/alerts.py index 703931e..6e9eb22 100644 --- a/src/d2/api/routes/alerts.py +++ b/src/d2/api/routes/alerts.py @@ -21,11 +21,11 @@ "/alerts", response_model=list[LiveAlert], tags=["alerts"], - dependencies=[Depends(require_roles("doctor", "admin", "patient"))], + dependencies=[require_roles("doctor", "admin", "patient")], ) async def get_alerts( session: DbSession, - user: dict = Depends(require_roles("doctor", "admin", "patient")), + user: dict = require_roles("doctor", "admin", "patient"), patient_id: int | None = None, limit: int = 500 ) -> list[LiveAlert]: @@ -52,9 +52,9 @@ async def get_alerts( "/alerts/{alert_id}/acknowledge", response_model=AcknowledgeAlertResponse, tags=["alerts"], - dependencies=[Depends(require_roles("doctor", "admin"))], + dependencies=[require_roles("doctor", "admin")], ) -async def acknowledge_alert(alert_id: str, session: DbSession, user: dict = Depends(require_roles("doctor", "admin"))) -> AcknowledgeAlertResponse: +async def acknowledge_alert(alert_id: str, session: DbSession, user: dict = require_roles("doctor", "admin")) -> AcknowledgeAlertResponse: try: aid = int(alert_id) except ValueError as exc: diff --git a/tests/integration/test_socket_bridge.py b/tests/integration/test_socket_bridge.py new file mode 100644 index 0000000..3cd966c --- /dev/null +++ b/tests/integration/test_socket_bridge.py @@ -0,0 +1,60 @@ +import threading +import uvicorn +import asyncio +import time +import requests +import socketio +import pytest +from d2.api.app import app +from d2.api.session import InMemorySessionStore + +@pytest.fixture(scope="module") +def server(): + # Use a thread so we share the InMemorySessionStore + def run_server(): + uvicorn.run(app, host="127.0.0.1", port=8889, log_level="error") + + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + + # Wait for server to start + timeout = 10 + start = time.time() + while time.time() - start < timeout: + try: + requests.get("http://127.0.0.1:8889/v1/health") + break + except: + time.sleep(0.5) + + yield "http://127.0.0.1:8889" + # Thread will be killed when main process exits since it's a daemon + +@pytest.mark.asyncio +async def test_socket_connection_refused_no_auth(server): + sio = socketio.AsyncClient() + with pytest.raises(socketio.exceptions.ConnectionError): + await sio.connect(server, auth={}) + +@pytest.mark.asyncio +async def test_socket_connection_success_with_auth(server): + # Create a fake session in memory + token = InMemorySessionStore.create_session({"role": "doctor", "sub": "test-sub"}) + + sio = socketio.AsyncClient() + await sio.connect(server, auth={"token": token}) + assert sio.sid is not None + await sio.disconnect() + +@pytest.mark.asyncio +async def test_socket_join_room(server): + token = InMemorySessionStore.create_session({"role": "patient", "sub": "p-sub"}) + sio = socketio.AsyncClient() + await sio.connect(server, auth={"token": token}) + + # We can't easily verify the room membership from the client side without emitting back + # But we can at least ensure the event doesn't crash the server + await sio.emit("join_room", {"roomId": "patient:123"}) + await asyncio.sleep(0.1) + + await sio.disconnect() diff --git a/tests/unit/test_alerts_route.py b/tests/unit/test_alerts_route.py new file mode 100644 index 0000000..cc9ca05 --- /dev/null +++ b/tests/unit/test_alerts_route.py @@ -0,0 +1,64 @@ +import pytest +from fastapi.testclient import TestClient +from datetime import datetime, timezone +from d2.api.app import app +from d2.api.deps import get_db, get_current_user +from d2.db import repository + +# We need to clear overrides after each test if we change them +@pytest.fixture(autouse=True) +def clear_overrides(): + app.dependency_overrides.clear() + yield + app.dependency_overrides.clear() + +async def fake_db(): + yield None + +app.dependency_overrides[get_db] = fake_db +client = TestClient(app) + +def test_get_alerts_as_doctor(monkeypatch): + user = {"role": "doctor", "preferred_username": "dr_test"} + app.dependency_overrides[get_current_user] = lambda: user + + async def fake_list(session, limit=500, patient_id=None): + return [{"id": 1, "device_id": "d1", "patient_id": patient_id, "ts": datetime.now(timezone.utc), "severity": "warning", "reason": "test", "acknowledged": False, "patient_full_name": "John Doe"}] + + monkeypatch.setattr(repository, "list_live_alerts", fake_list) + + response = client.get("/api/alerts") + assert response.status_code == 200 + assert response.json()[0]["patientId"] == "" # row_to_live_alert returns "" if patient_id is None + +def test_get_alerts_as_patient(monkeypatch): + user = {"role": "patient", "auth_subject_id": 123} + app.dependency_overrides[get_current_user] = lambda: user + + async def fake_fetch_pid(session, auth_id): + assert auth_id == 123 + return 456 + + async def fake_list(session, limit=500, patient_id=None): + assert patient_id == 456 + return [{"id": 1, "device_id": "d1", "patient_id": 456, "ts": datetime.now(timezone.utc), "severity": "warning", "reason": "test", "acknowledged": False, "patient_full_name": "Patient X"}] + + monkeypatch.setattr(repository, "fetch_patient_id_by_auth_subject", fake_fetch_pid) + monkeypatch.setattr(repository, "list_live_alerts", fake_list) + + response = client.get("/api/alerts") + assert response.status_code == 200 + assert response.json()[0]["patientId"] == "P456" + +def test_get_alerts_patient_not_found(monkeypatch): + user = {"role": "patient", "auth_subject_id": 123} + app.dependency_overrides[get_current_user] = lambda: user + + async def fake_fetch_pid(session, auth_id): + return None + + monkeypatch.setattr(repository, "fetch_patient_id_by_auth_subject", fake_fetch_pid) + + response = client.get("/api/alerts") + assert response.status_code == 403 + assert "Patient profile not found" in response.text diff --git a/tests/unit/test_app_logic.py b/tests/unit/test_app_logic.py new file mode 100644 index 0000000..d408f0d --- /dev/null +++ b/tests/unit/test_app_logic.py @@ -0,0 +1,25 @@ +from d2.api.app import _extract_auth_token, _split_origins + +def test_extract_auth_token_bearer(): + auth = {"token": "Bearer my-secret-token"} + assert _extract_auth_token(auth) == "my-secret-token" + +def test_extract_auth_token_plain(): + auth = {"token": "plain-token"} + assert _extract_auth_token(auth) == "plain-token" + +def test_extract_auth_token_invalid(): + assert _extract_auth_token(None) is None + assert _extract_auth_token({}) is None + assert _extract_auth_token({"no_token": "foo"}) is None + +def test_split_origins_single(): + assert _split_origins("http://localhost:3000") == ["http://localhost:3000"] + +def test_split_origins_multiple(): + raw = "http://localhost:3000, https://myapp.com " + assert _split_origins(raw) == ["http://localhost:3000", "https://myapp.com"] + +def test_split_origins_empty(): + assert _split_origins(None) == [] + assert _split_origins("") == [] From 2ff55dce960d0e4bd9c5053c7cdc5587004e92d0 Mon Sep 17 00:00:00 2001 From: Izzat Nisfer Date: Fri, 8 May 2026 10:27:27 +0530 Subject: [PATCH 3/6] feat: add simulate_alert script for publishing vitals payloads to Kafka --- scripts/simulate_alert.py | 58 ++++++++ tests/integration/test_worker_pipeline.py | 154 ++++++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 scripts/simulate_alert.py create mode 100644 tests/integration/test_worker_pipeline.py diff --git a/scripts/simulate_alert.py b/scripts/simulate_alert.py new file mode 100644 index 0000000..24a3cea --- /dev/null +++ b/scripts/simulate_alert.py @@ -0,0 +1,58 @@ +import argparse +import time +from datetime import datetime, timezone + +from d2.config import get_settings +from d2.streaming.kafka_io import create_producer + + +def _build_payload(*, critical: bool) -> dict[str, object]: + now = datetime.now(tz=timezone.utc) + payload = { + "timestamp": now.timestamp(), + "device_id": "device_1", + "pulse": 78, + "blood_oxygen": 93.0, + "temperature": 36.8, + "IsFalling": False, + "IsMoving": False, + } + if critical: + payload["blood_oxygen"] = 85.0 + return payload + + +def main() -> int: + parser = argparse.ArgumentParser(description="Publish a single vitals payload to Kafka.") + parser.add_argument( + "--critical", + action="store_true", + help="Send a critical SpO2 reading to trigger alert logic downstream.", + ) + args = parser.parse_args() + + settings = get_settings() + producer = create_producer(settings) + payload = _build_payload(critical=args.critical) + + try: + future = producer.send( + settings.kafka_topic_vitals_raw, + key=str(payload["device_id"]).encode("utf-8"), + value=payload, + ) + future.get(timeout=10) + producer.flush(timeout=10) + print( + f"Sent payload to {settings.kafka_topic_vitals_raw} at {time.strftime('%H:%M:%S')}: {payload}" + ) + return 0 + except Exception as exc: + print(f"Failed to publish payload: {exc}") + return 1 + finally: + producer.close() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/integration/test_worker_pipeline.py b/tests/integration/test_worker_pipeline.py new file mode 100644 index 0000000..d873c5f --- /dev/null +++ b/tests/integration/test_worker_pipeline.py @@ -0,0 +1,154 @@ +import asyncio +import sys +from collections import defaultdict, deque +from types import ModuleType, SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from d2.config import Settings + +_influx_module = ModuleType("influxdb_client") +_influx_module.InfluxDBClient = object +_influx_module.Point = object +_influx_client_module = ModuleType("influxdb_client.client") +_influx_write_api_module = ModuleType("influxdb_client.client.write_api") +_influx_write_api_module.SYNCHRONOUS = object() +_influx_domain_module = ModuleType("influxdb_client.domain") +_influx_write_precision_module = ModuleType("influxdb_client.domain.write_precision") + +class _WritePrecision: + NS = "ns" + +_influx_write_precision_module.WritePrecision = _WritePrecision + +sys.modules.setdefault("influxdb_client", _influx_module) +sys.modules.setdefault("influxdb_client.client", _influx_client_module) +sys.modules.setdefault("influxdb_client.client.write_api", _influx_write_api_module) +sys.modules.setdefault("influxdb_client.domain", _influx_domain_module) +sys.modules.setdefault("influxdb_client.domain.write_precision", _influx_write_precision_module) + +from d2.streaming import worker + + +def _make_message(payload): + return SimpleNamespace(value=payload) + + +@pytest.fixture +def event_loop(): + loop = asyncio.new_event_loop() + try: + yield loop + finally: + loop.close() + + +def _base_settings(): + return Settings( + kafka_topic_vitals_raw="vitals.raw", + kafka_topic_vitals_scored="vitals.scored", + kafka_topic_alerts="alerts", + model_version="test-0", + ) + + +def _healthy_payload(): + return { + "device_id": "dev-1", + "pulse": 72, + "blood_oxygen": 98.0, + "temperature": 36.8, + "timestamp": 1715123456, + "IsFalling": False, + "IsMoving": True, + } + + +def _critical_payload(): + payload = _healthy_payload() + payload["blood_oxygen"] = 85.0 + return payload + + +def test_worker_pipeline_normal_payload(event_loop): + settings = _base_settings() + history = defaultdict(lambda: deque(maxlen=4000)) + producer = MagicMock() + session_maker = MagicMock() + + score = { + "overall_status": "normal", + "risk_level": 0, + "risk_score": 0.05, + "model_version": settings.model_version, + } + + with patch.object(worker, "_upsert_device", AsyncMock()) as upsert_device, patch.object( + worker, "_persist_prediction", AsyncMock() + ) as persist_prediction, patch.object(worker, "_persist_alert", AsyncMock()) as persist_alert, patch.object( + worker.inference, "score_window", return_value=score + ): + worker._process_one( + _make_message(_healthy_payload()), + history, + model="mock-model", + settings=settings, + session_maker=session_maker, + producer=producer, + loop=event_loop, + influx_writer=None, + lstm_model=None, + ) + + upsert_device.assert_called_once_with(session_maker, "dev-1") + persist_prediction.assert_called_once() + persist_alert.assert_not_called() + + assert producer.send.call_count == 1 + args, kwargs = producer.send.call_args + assert args[0] == settings.kafka_topic_vitals_scored + assert kwargs["value"]["overall_status"] == "normal" + + +def test_worker_pipeline_critical_payload(event_loop): + settings = _base_settings() + history = defaultdict(lambda: deque(maxlen=4000)) + producer = MagicMock() + session_maker = MagicMock() + + score = { + "overall_status": "critical", + "risk_level": 2, + "risk_score": 0.95, + "model_version": settings.model_version, + } + + with patch.object(worker, "_upsert_device", AsyncMock()) as upsert_device, patch.object( + worker, "_persist_prediction", AsyncMock() + ) as persist_prediction, patch.object(worker, "_persist_alert", AsyncMock()) as persist_alert, patch.object( + worker.inference, "score_window", return_value=score + ): + worker._process_one( + _make_message(_critical_payload()), + history, + model="mock-model", + settings=settings, + session_maker=session_maker, + producer=producer, + loop=event_loop, + influx_writer=None, + lstm_model=None, + ) + + upsert_device.assert_called_once_with(session_maker, "dev-1") + persist_prediction.assert_called_once() + persist_alert.assert_called_once() + + assert producer.send.call_count == 2 + scored_call = producer.send.call_args_list[0] + alert_call = producer.send.call_args_list[1] + assert scored_call.args[0] == settings.kafka_topic_vitals_scored + assert scored_call.kwargs["value"]["overall_status"] == "critical" + assert alert_call.args[0] == settings.kafka_topic_alerts + assert "reason" in alert_call.kwargs["value"] From 6e1ccb74ea4b15ae6a703607b8240dcfd133c8d3 Mon Sep 17 00:00:00 2001 From: Izzat Nisfer Date: Sun, 10 May 2026 07:18:54 +0530 Subject: [PATCH 4/6] fix(ci): add aiokafka, python-socketio, starlette to dev extras for test imports --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 91cca7e..7e775e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,9 @@ dev = [ "pytest-asyncio>=0.23", "httpx>=0.24", "ruff>=0.1.0", + "aiokafka>=0.10", + "python-socketio>=5.11", + "starlette>=0.27", ] [project.scripts] From a11d39e2fe797d3d0464806c33a2ebfea6a3cfad Mon Sep 17 00:00:00 2001 From: Izzat Nisfer Date: Sun, 10 May 2026 07:20:04 +0530 Subject: [PATCH 5/6] fix: add itsdangerous dependency required by starlette SessionMiddleware --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index a999dd9..0982d90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,9 @@ aiokafka>=0.10 # Real-time alerts (Socket.IO) python-socketio>=5.11 +# Session middleware (SessionMiddleware requires itsdangerous) +itsdangerous>=2.1 + # Time-series (InfluxDB 2.x — worker vitals writes) influxdb-client>=1.38,<2 From 27525ef42d4ff18396c35a8630e471b4c43483cd Mon Sep 17 00:00:00 2001 From: Izzat Nisfer Date: Sun, 10 May 2026 07:23:18 +0530 Subject: [PATCH 6/6] fix: add require_roles dependency to deps.py and clean up alerts route signatures --- src/d2/api/deps.py | 33 ++++++++++++++++++++++++--------- src/d2/api/routes/alerts.py | 23 ++++++++++++----------- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/d2/api/deps.py b/src/d2/api/deps.py index dc90ef4..36bb57e 100644 --- a/src/d2/api/deps.py +++ b/src/d2/api/deps.py @@ -1,16 +1,8 @@ -""" -FastAPI dependencies: settings, database sessions, (future) auth. - -`get_db` yields one async SQLAlchemy session per request. Auth is a stub until -D4 wires Keycloak — routes that need it should still `Depends(get_current_user)` -so the call site doesn't change later. -""" - from __future__ import annotations from typing import Annotated, AsyncIterator -from fastapi import Depends +from fastapi import Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from d2.config import Settings, get_settings @@ -41,3 +33,26 @@ async def get_current_user() -> dict[str, str]: CurrentUser = Annotated[dict, Depends(get_current_user)] + + +def require_roles(*allowed_roles: str): + """FastAPI dependency factory — returns a Depends() that enforces role access. + + Usage: + @router.get("/path", dependencies=[Depends(require_roles("doctor", "admin"))]) + async def handler(user: dict = Depends(require_roles("doctor", "admin"))): + ... + + While get_current_user is a stub, all roles pass. Once D4 wires real JWT + validation, this will enforce actual Keycloak roles without changing call sites. + """ + async def _check(user: dict = Depends(get_current_user)) -> dict: + role = str(user.get("role") or "").lower() + if allowed_roles and role not in [r.lower() for r in allowed_roles]: + raise HTTPException( + status_code=403, + detail=f"Role '{role}' is not permitted. Required: {list(allowed_roles)}", + ) + return user + + return Depends(_check) diff --git a/src/d2/api/routes/alerts.py b/src/d2/api/routes/alerts.py index 6e9eb22..5401575 100644 --- a/src/d2/api/routes/alerts.py +++ b/src/d2/api/routes/alerts.py @@ -7,7 +7,7 @@ from sqlalchemy import text from d2.api.converters import row_to_live_alert -from d2.api.deps import DbSession, require_roles +from d2.api.deps import DbSession, get_current_user, require_roles from d2.api.schemas import AcknowledgeAlertResponse, LiveAlert from d2.db import repository from d2.observability.metrics import ALERTS_ACKNOWLEDGED @@ -21,7 +21,6 @@ "/alerts", response_model=list[LiveAlert], tags=["alerts"], - dependencies=[require_roles("doctor", "admin", "patient")], ) async def get_alerts( session: DbSession, @@ -31,19 +30,19 @@ async def get_alerts( ) -> list[LiveAlert]: """All alerts, newest first (see ``alerts`` + ``patients`` join).""" role = str(user.get("role") or "").lower() - + # If the user is a patient, they can only see their own alerts if "patient" in role: auth_subject_id = user.get("auth_subject_id") if not auth_subject_id: raise HTTPException(status_code=403, detail="Patient profile not linked") - + pid = await repository.fetch_patient_id_by_auth_subject(session, auth_subject_id) if not pid: raise HTTPException(status_code=403, detail="Patient profile not found") - + patient_id = pid - + rows = await repository.list_live_alerts(session, limit=limit, patient_id=patient_id) return [row_to_live_alert(dict(r)) for r in rows] @@ -52,9 +51,12 @@ async def get_alerts( "/alerts/{alert_id}/acknowledge", response_model=AcknowledgeAlertResponse, tags=["alerts"], - dependencies=[require_roles("doctor", "admin")], ) -async def acknowledge_alert(alert_id: str, session: DbSession, user: dict = require_roles("doctor", "admin")) -> AcknowledgeAlertResponse: +async def acknowledge_alert( + alert_id: str, + session: DbSession, + user: dict = require_roles("doctor", "admin"), +) -> AcknowledgeAlertResponse: try: aid = int(alert_id) except ValueError as exc: @@ -68,11 +70,10 @@ async def acknowledge_alert(alert_id: str, session: DbSession, user: dict = requ await repository.acknowledge_alert(session, aid) ALERTS_ACKNOWLEDGED.inc() - - # Log acknowledgment + username = user.get("preferred_username") or user.get("email") or "unknown_user" log.info(f"Alert {alert_id} acknowledged by user: {username} (auth_subject_id: {user.get('auth_subject_id')})") - + return AcknowledgeAlertResponse( success=True, message=f"Alert {alert_id} acknowledged successfully" )