diff --git a/requirements.txt b/requirements.txt index 7703e8f..5386a15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,11 @@ vitaldb>=1.4 # Streaming (member: Kafka track) kafka-python>=2.0 +aiokafka>=0.10 + +# Real-time alerts +python-socketio>=5.11 +itsdangerous>=2.1 # Time-series (InfluxDB 2.x — worker vitals writes) influxdb-client>=1.38,<2 @@ -34,4 +39,6 @@ opentelemetry-api>=1.20 # Dev / test pytest>=7.0 +pytest-asyncio>=0.23 httpx>=0.24 +websockets>=11.0 diff --git a/scripts/simulate_alert.py b/scripts/simulate_alert.py new file mode 100644 index 0000000..24a3cea --- /dev/null +++ b/scripts/simulate_alert.py @@ -0,0 +1,58 @@ +import argparse +import time +from datetime import datetime, timezone + +from d2.config import get_settings +from d2.streaming.kafka_io import create_producer + + +def _build_payload(*, critical: bool) -> dict[str, object]: + now = datetime.now(tz=timezone.utc) + payload = { + "timestamp": now.timestamp(), + "device_id": "device_1", + "pulse": 78, + "blood_oxygen": 93.0, + "temperature": 36.8, + "IsFalling": False, + "IsMoving": False, + } + if critical: + payload["blood_oxygen"] = 85.0 + return payload + + +def main() -> int: + parser = argparse.ArgumentParser(description="Publish a single vitals payload to Kafka.") + parser.add_argument( + "--critical", + action="store_true", + help="Send a critical SpO2 reading to trigger alert logic downstream.", + ) + args = parser.parse_args() + + settings = get_settings() + producer = create_producer(settings) + payload = _build_payload(critical=args.critical) + + try: + future = producer.send( + settings.kafka_topic_vitals_raw, + key=str(payload["device_id"]).encode("utf-8"), + value=payload, + ) + future.get(timeout=10) + producer.flush(timeout=10) + print( + f"Sent payload to {settings.kafka_topic_vitals_raw} at {time.strftime('%H:%M:%S')}: {payload}" + ) + return 0 + except Exception as exc: + print(f"Failed to publish payload: {exc}") + return 1 + finally: + producer.close() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/d2/api/app.py b/src/d2/api/app.py index 6eae59d..5a51cd3 100644 --- a/src/d2/api/app.py +++ b/src/d2/api/app.py @@ -8,11 +8,17 @@ from __future__ import annotations +import asyncio +import json import logging import socket +import ssl import time from contextlib import asynccontextmanager +from typing import Any +import socketio +from aiokafka import AIOKafkaConsumer from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.requests import Request @@ -21,9 +27,17 @@ from sqlalchemy.exc import OperationalError from d2.api.routes import health, vitals, vitals_routes, patients, alerts, devices, reports +from d2.api.schemas.alert import AlertPayload from d2.config import database_hostname, get_settings -from d2.db.connection import dispose_engine -from d2.observability.metrics import API_ERRORS_TOTAL, API_REQUESTS_TOTAL, API_RESPONSE_SECONDS +from d2.db import repository +from d2.db.connection import dispose_engine, get_session_maker +from d2.observability.metrics import ( + ALERTS_DELIVERED, + ALERTS_DELIVERY_FAILED, + API_ERRORS_TOTAL, + API_REQUESTS_TOTAL, + API_RESPONSE_SECONDS, +) log = logging.getLogger(__name__) @@ -34,6 +48,146 @@ ) } +def _split_origins(raw: str | None) -> list[str]: + if not raw: + return [] + return [origin.strip() for origin in raw.split(",") if origin.strip()] + +def format_patient_slug(patient_id: int) -> str: + return f"PAT-{patient_id:04d}" + +def _severity_from_alert(alert: dict[str, Any]) -> str: + if alert.get("is_falling"): + return "critical" + for key in ("hr_status", "spo2_status", "temp_status"): + if alert.get(key) == "critical": + return "critical" + return "warning" + +def _message_from_alert(alert: dict[str, Any]) -> str: + reason = alert.get("reason", "Unknown anomaly detected") + return reason.replace("_", " ").capitalize() + +async def _resolve_patient_context( + session: Any, alert: dict[str, Any] +) -> tuple[int | None, str]: + patient_id = alert.get("patient_id") + patient_name = "Unknown Patient" + + if patient_id: + row = await repository.fetch_patient_by_id(session, int(patient_id)) + if row and row.get("full_name"): + patient_name = str(row["full_name"]).strip() + + return (int(patient_id) if patient_id else None, patient_name) + +def _build_alert_payload( + alert: dict[str, Any], + patient_id: int | None, + patient_name: str, +) -> AlertPayload: + patient_slug = format_patient_slug(patient_id or 0) + return AlertPayload( + patientId=patient_slug, + patientName=patient_name, + message=_message_from_alert(alert), + severity=_severity_from_alert(alert), + ) + +def _build_kafka_ssl_context(settings: Any) -> ssl.SSLContext | None: + if settings.kafka_security_protocol.upper() == "PLAINTEXT": + return None + context = ssl.create_default_context() + # In main, we might not have resolve_repo_relative_path, so just use raw paths or skip if not found + ca = getattr(settings, "kafka_ssl_cafile", None) + cert = getattr(settings, "kafka_ssl_certfile", None) + key = getattr(settings, "kafka_ssl_keyfile", None) + if ca: + context.load_verify_locations(cafile=ca) + if cert and key: + context.load_cert_chain(certfile=cert, keyfile=key) + return context + +def _register_socket_handlers( + socket_server: socketio.AsyncServer, + *, + require_auth: bool, +) -> None: + @socket_server.event + async def connect(sid, environ, auth): + # We removed InMemorySessionStore logic here, we just accept connections + # Keycloak integration will happen in D4. + log.info("[Socket] client connected sid=%s", sid) + + @socket_server.event + async def disconnect(sid): + log.info("[Socket] client disconnected sid=%s", sid) + + @socket_server.on("join_room") + async def join_room(sid, payload): + room_id = None + if isinstance(payload, dict): + room_id = payload.get("roomId") or payload.get("room_id") + if room_id: + await socket_server.enter_room(sid, str(room_id)) + log.info("[Socket] %s joined room=%s", sid, room_id) + + @socket_server.on("leave_room") + async def leave_room(sid, payload): + room_id = None + if isinstance(payload, dict): + room_id = payload.get("roomId") or payload.get("room_id") + if room_id: + await socket_server.leave_room(sid, str(room_id)) + log.info("[Socket] %s left room=%s", sid, room_id) + +async def _consume_alerts(socket_server: socketio.AsyncServer, settings: Any) -> None: + ssl_context = _build_kafka_ssl_context(settings) + session_maker = get_session_maker() + retry_delay = 5 + + while True: + try: + consumer = AIOKafkaConsumer( + getattr(settings, "kafka_topic_alerts", "alerts"), + bootstrap_servers=settings.kafka_bootstrap_servers.split(","), + group_id="d2-socketio-broadcaster", + value_deserializer=lambda m: json.loads(m.decode("utf-8")), + enable_auto_commit=True, + security_protocol=settings.kafka_security_protocol, + ssl_context=ssl_context, + ) + await consumer.start() + log.info("Socket.IO alert consumer started topic=%s", getattr(settings, "kafka_topic_alerts", "alerts")) + async for msg in consumer: + alert = msg.value + async with session_maker() as session: + patient_id, patient_name = await _resolve_patient_context(session, alert) + payload = _build_alert_payload(alert, patient_id, patient_name) + severity = payload.severity + + try: + await socket_server.emit("alert", payload.model_dump()) + ALERTS_DELIVERED.labels(severity=severity).inc() + log.info( + "[Socket] alert broadcast severity=%s patient_id=%s", + severity, + patient_id, + ) + except Exception: + ALERTS_DELIVERY_FAILED.labels(severity=severity).inc() + log.exception("[Socket] alert broadcast failed") + except asyncio.CancelledError: + break + except Exception: + log.exception("Socket.IO alert consumer failed; retrying in %ss", retry_delay) + await asyncio.sleep(retry_delay) + finally: + try: + await consumer.stop() + except Exception: + pass + @asynccontextmanager async def _lifespan(app: FastAPI): @@ -44,11 +198,26 @@ async def _lifespan(app: FastAPI): database_hostname(settings.database_url), settings.pg_ssl_required, ) + + socket_server: socketio.AsyncServer | None = getattr(app.state, "socketio_server", None) + kafka_task: asyncio.Task | None = None + if socket_server is not None: + kafka_task = asyncio.create_task(_consume_alerts(socket_server, settings)) + yield + + if kafka_task is not None: + kafka_task.cancel() + try: + await kafka_task + except asyncio.CancelledError: + pass + await dispose_engine() def create_app() -> FastAPI: + settings = get_settings() app = FastAPI( title="D2 Intelligence API", description="Data & Intelligence layer — vitals scoring, history, alerts integration.", @@ -56,6 +225,14 @@ def create_app() -> FastAPI: lifespan=_lifespan, ) + socketio_origins = _split_origins(settings.socketio_allow_origins or settings.cors_allow_origins) + socket_server = socketio.AsyncServer( + async_mode="asgi", + cors_allowed_origins=socketio_origins or ["*"], + ) + _register_socket_handlers(socket_server, require_auth=settings.socketio_require_auth) + app.state.socketio_server = socket_server + # CORS — allow all origins for dev / demo; restrict to D3 origin in production. app.add_middleware( CORSMiddleware, @@ -88,6 +265,8 @@ async def _metrics_middleware(request: Request, call_next): status=str(status_code), ).inc() + app.mount("/socket.io", socketio.ASGIApp(socket_server)) + @app.get("/metrics", include_in_schema=False) def metrics() -> Response: """Prometheus scrape endpoint.""" @@ -98,7 +277,6 @@ def _db_unreachable_response( ) -> JSONResponse: log.warning("Database connection failed (%s): %s", label, exc) return JSONResponse(status_code=503, content=_DB_UNREACHABLE_BODY) - @app.exception_handler(OperationalError) async def _db_sa_error(request: Request, exc: OperationalError) -> JSONResponse: diff --git a/src/d2/api/routes/alerts.py b/src/d2/api/routes/alerts.py index 0b6ab2b..448f950 100644 --- a/src/d2/api/routes/alerts.py +++ b/src/d2/api/routes/alerts.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from fastapi import APIRouter, HTTPException from sqlalchemy import text @@ -9,14 +10,16 @@ from d2.api.deps import DbSession from d2.api.schemas import AcknowledgeAlertResponse, LiveAlert from d2.db import repository +from d2.observability.metrics import ALERTS_ACKNOWLEDGED +log = logging.getLogger(__name__) router = APIRouter() @router.get("/alerts", response_model=list[LiveAlert], tags=["alerts"]) -async def get_alerts(session: DbSession, limit: int = 500) -> list[LiveAlert]: +async def get_alerts(session: DbSession, limit: int = 500, patient_id: int | None = None) -> list[LiveAlert]: """All alerts, newest first (see ``alerts`` + ``patients`` join).""" - rows = await repository.list_live_alerts(session, limit=limit) + rows = await repository.list_live_alerts(session, limit=limit, patient_id=patient_id) return [row_to_live_alert(dict(r)) for r in rows] @@ -34,6 +37,9 @@ async def acknowledge_alert(alert_id: str, session: DbSession) -> AcknowledgeAle raise HTTPException(status_code=404, detail=f"Alert with id '{alert_id}' not found") await repository.acknowledge_alert(session, aid) + ALERTS_ACKNOWLEDGED.inc() + log.info(f"Alert {alert_id} acknowledged") + return AcknowledgeAlertResponse( success=True, message=f"Alert {alert_id} acknowledged successfully" ) diff --git a/src/d2/config.py b/src/d2/config.py index 4fcd05d..8977e2b 100644 --- a/src/d2/config.py +++ b/src/d2/config.py @@ -90,6 +90,11 @@ def influx_configured(self) -> bool: # Optional Keycloak checks (D4 will supply values) keycloak_well_known_url: str | None = None + # CORS and Socket.IO + cors_allow_origins: str | None = None + socketio_allow_origins: str | None = None + socketio_require_auth: bool = False + @model_validator(mode="after") def ensure_asyncpg_driver_in_database_url(self) -> Self: """SQLAlchemy AsyncEngine requires asyncpg. Supabase / dashboards often paste diff --git a/src/d2/db/repository.py b/src/d2/db/repository.py index 2b92036..26bc834 100644 --- a/src/d2/db/repository.py +++ b/src/d2/db/repository.py @@ -322,6 +322,7 @@ async def list_alerts( FROM alerts a LEFT JOIN patients p ON p.id = a.patient_id LEFT JOIN users u ON u.id = p.user_id + WHERE (:patient_id::int IS NULL OR a.patient_id = :patient_id) ORDER BY a.ts DESC LIMIT :limit """ @@ -505,8 +506,8 @@ async def fetch_vitals_for_patient_range( return [dict(r) for r in result.mappings().all()] -async def list_live_alerts(session: AsyncSession, limit: int = 500) -> list[dict[str, Any]]: - result = await session.execute(_LIST_ALERTS_LIVE, {"limit": limit}) +async def list_live_alerts(session: AsyncSession, limit: int = 500, patient_id: int | None = None) -> list[dict[str, Any]]: + result = await session.execute(_LIST_ALERTS_LIVE, {"limit": limit, "patient_id": patient_id}) return [dict(r) for r in result.mappings().all()] diff --git a/src/d2/observability/metrics.py b/src/d2/observability/metrics.py index 1021e08..1c370da 100644 --- a/src/d2/observability/metrics.py +++ b/src/d2/observability/metrics.py @@ -114,4 +114,21 @@ "d2_model_inference_failures_total", "Number of model inference failures (by model).", labelnames=("model", "stage"), +) + +ALERTS_DELIVERED = Counter( + "d2_alerts_delivered_total", + "Alerts successfully broadcast to Socket.IO clients.", + labelnames=("severity",) +) + +ALERTS_DELIVERY_FAILED = Counter( + "d2_alerts_delivery_failed_total", + "Alerts that failed to broadcast to Socket.IO clients.", + labelnames=("severity",) +) + +ALERTS_ACKNOWLEDGED = Counter( + "d2_alerts_acknowledged_total", + "Alerts explicitly acknowledged by a user via the API." ) \ No newline at end of file diff --git a/tests/integration/test_socket_bridge.py b/tests/integration/test_socket_bridge.py new file mode 100644 index 0000000..afb1fca --- /dev/null +++ b/tests/integration/test_socket_bridge.py @@ -0,0 +1,43 @@ +import threading +import uvicorn +import asyncio +import time +import requests +import socketio +import pytest +from d2.api.app import app + +@pytest.fixture(scope="module") +def server(): + def run_server(): + uvicorn.run(app, host="127.0.0.1", port=8889, log_level="error") + + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + + # Wait for server to start + timeout = 10 + start = time.time() + while time.time() - start < timeout: + try: + requests.get("http://127.0.0.1:8889/v1/health") + break + except: + time.sleep(0.5) + + yield "http://127.0.0.1:8889" + +def test_socket_connection_success(server): + sio = socketio.Client() + sio.connect(server) + assert sio.sid is not None + sio.disconnect() + +def test_socket_join_room(server): + sio = socketio.Client() + sio.connect(server) + + sio.emit("join_room", {"roomId": "patient:123"}) + time.sleep(0.1) + + sio.disconnect() diff --git a/tests/integration/test_worker_pipeline.py b/tests/integration/test_worker_pipeline.py new file mode 100644 index 0000000..d873c5f --- /dev/null +++ b/tests/integration/test_worker_pipeline.py @@ -0,0 +1,154 @@ +import asyncio +import sys +from collections import defaultdict, deque +from types import ModuleType, SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from d2.config import Settings + +_influx_module = ModuleType("influxdb_client") +_influx_module.InfluxDBClient = object +_influx_module.Point = object +_influx_client_module = ModuleType("influxdb_client.client") +_influx_write_api_module = ModuleType("influxdb_client.client.write_api") +_influx_write_api_module.SYNCHRONOUS = object() +_influx_domain_module = ModuleType("influxdb_client.domain") +_influx_write_precision_module = ModuleType("influxdb_client.domain.write_precision") + +class _WritePrecision: + NS = "ns" + +_influx_write_precision_module.WritePrecision = _WritePrecision + +sys.modules.setdefault("influxdb_client", _influx_module) +sys.modules.setdefault("influxdb_client.client", _influx_client_module) +sys.modules.setdefault("influxdb_client.client.write_api", _influx_write_api_module) +sys.modules.setdefault("influxdb_client.domain", _influx_domain_module) +sys.modules.setdefault("influxdb_client.domain.write_precision", _influx_write_precision_module) + +from d2.streaming import worker + + +def _make_message(payload): + return SimpleNamespace(value=payload) + + +@pytest.fixture +def event_loop(): + loop = asyncio.new_event_loop() + try: + yield loop + finally: + loop.close() + + +def _base_settings(): + return Settings( + kafka_topic_vitals_raw="vitals.raw", + kafka_topic_vitals_scored="vitals.scored", + kafka_topic_alerts="alerts", + model_version="test-0", + ) + + +def _healthy_payload(): + return { + "device_id": "dev-1", + "pulse": 72, + "blood_oxygen": 98.0, + "temperature": 36.8, + "timestamp": 1715123456, + "IsFalling": False, + "IsMoving": True, + } + + +def _critical_payload(): + payload = _healthy_payload() + payload["blood_oxygen"] = 85.0 + return payload + + +def test_worker_pipeline_normal_payload(event_loop): + settings = _base_settings() + history = defaultdict(lambda: deque(maxlen=4000)) + producer = MagicMock() + session_maker = MagicMock() + + score = { + "overall_status": "normal", + "risk_level": 0, + "risk_score": 0.05, + "model_version": settings.model_version, + } + + with patch.object(worker, "_upsert_device", AsyncMock()) as upsert_device, patch.object( + worker, "_persist_prediction", AsyncMock() + ) as persist_prediction, patch.object(worker, "_persist_alert", AsyncMock()) as persist_alert, patch.object( + worker.inference, "score_window", return_value=score + ): + worker._process_one( + _make_message(_healthy_payload()), + history, + model="mock-model", + settings=settings, + session_maker=session_maker, + producer=producer, + loop=event_loop, + influx_writer=None, + lstm_model=None, + ) + + upsert_device.assert_called_once_with(session_maker, "dev-1") + persist_prediction.assert_called_once() + persist_alert.assert_not_called() + + assert producer.send.call_count == 1 + args, kwargs = producer.send.call_args + assert args[0] == settings.kafka_topic_vitals_scored + assert kwargs["value"]["overall_status"] == "normal" + + +def test_worker_pipeline_critical_payload(event_loop): + settings = _base_settings() + history = defaultdict(lambda: deque(maxlen=4000)) + producer = MagicMock() + session_maker = MagicMock() + + score = { + "overall_status": "critical", + "risk_level": 2, + "risk_score": 0.95, + "model_version": settings.model_version, + } + + with patch.object(worker, "_upsert_device", AsyncMock()) as upsert_device, patch.object( + worker, "_persist_prediction", AsyncMock() + ) as persist_prediction, patch.object(worker, "_persist_alert", AsyncMock()) as persist_alert, patch.object( + worker.inference, "score_window", return_value=score + ): + worker._process_one( + _make_message(_critical_payload()), + history, + model="mock-model", + settings=settings, + session_maker=session_maker, + producer=producer, + loop=event_loop, + influx_writer=None, + lstm_model=None, + ) + + upsert_device.assert_called_once_with(session_maker, "dev-1") + persist_prediction.assert_called_once() + persist_alert.assert_called_once() + + assert producer.send.call_count == 2 + scored_call = producer.send.call_args_list[0] + alert_call = producer.send.call_args_list[1] + assert scored_call.args[0] == settings.kafka_topic_vitals_scored + assert scored_call.kwargs["value"]["overall_status"] == "critical" + assert alert_call.args[0] == settings.kafka_topic_alerts + assert "reason" in alert_call.kwargs["value"] diff --git a/tests/unit/test_alert_payload.py b/tests/unit/test_alert_payload.py new file mode 100644 index 0000000..5a0b4c0 --- /dev/null +++ b/tests/unit/test_alert_payload.py @@ -0,0 +1,23 @@ +from d2.api.app import _build_alert_payload, _message_from_alert, _severity_from_alert + + +def test_message_from_alert_defaults(): + assert _message_from_alert({}) == "Unknown anomaly detected" + + +def test_severity_critical_on_fall(): + alert = {"is_falling": True} + assert _severity_from_alert(alert) == "critical" + + +def test_severity_warning_on_status(): + alert = {"hr_status": "warning"} + assert _severity_from_alert(alert) == "warning" + + +def test_build_alert_payload_defaults_patient(): + alert = {"reason": "hr=130 (critical)", "hr_status": "critical"} + payload = _build_alert_payload(alert, None, "Unknown patient") + assert payload.patientId == "PAT-0000" + assert payload.severity == "critical" + assert payload.message == "Hr=130 (critical)" diff --git a/tests/unit/test_alerts_route.py b/tests/unit/test_alerts_route.py new file mode 100644 index 0000000..98274a9 --- /dev/null +++ b/tests/unit/test_alerts_route.py @@ -0,0 +1,65 @@ +import pytest +from fastapi.testclient import TestClient +from datetime import datetime, timezone +from d2.api.app import app +from d2.api.deps import get_db +from d2.db import repository + +@pytest.fixture(autouse=True) +def clear_overrides(): + app.dependency_overrides.clear() + yield + app.dependency_overrides.clear() + +async def fake_db(): + yield None + +app.dependency_overrides[get_db] = fake_db +client = TestClient(app) + +def test_get_alerts(monkeypatch): + async def fake_list(session, limit=500, patient_id=None): + return [{"id": 1, "device_id": "d1", "patient_id": patient_id, "ts": datetime.now(timezone.utc), "severity": "warning", "reason": "test", "acknowledged": False, "patient_full_name": "John Doe"}] + + monkeypatch.setattr(repository, "list_live_alerts", fake_list) + + response = client.get("/api/alerts") + assert response.status_code == 200 + assert response.json()[0]["patientId"] == "" + +def test_get_alerts_with_patient_id(monkeypatch): + async def fake_list(session, limit=500, patient_id=None): + assert patient_id == 456 + return [{"id": 1, "device_id": "d1", "patient_id": 456, "ts": datetime.now(timezone.utc), "severity": "warning", "reason": "test", "acknowledged": False, "patient_full_name": "Patient X"}] + + monkeypatch.setattr(repository, "list_live_alerts", fake_list) + + response = client.get("/api/alerts?patient_id=456") + assert response.status_code == 200 + assert response.json()[0]["patientId"] == "P456" + +def test_acknowledge_alert_not_found(monkeypatch): + async def fake_execute(*args, **kwargs): + class MockResult: + def scalar_one_or_none(self): + return None + return MockResult() + + monkeypatch.setattr("d2.api.deps.AsyncSession.execute", fake_execute, raising=False) + + # We also need to mock the session execute directly since it's a fake_db yielding None + # Let's just override the route dependency with a mock session + class MockSession: + async def execute(self, *args, **kwargs): + class MockResult: + def scalar_one_or_none(self): + return None + return MockResult() + + async def fake_mock_db(): + yield MockSession() + + app.dependency_overrides[get_db] = fake_mock_db + + response = client.patch("/api/alerts/999/acknowledge") + assert response.status_code == 404