Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ vitaldb>=1.4

# Streaming (member: Kafka track)
kafka-python>=2.0
aiokafka>=0.10

# Real-time alerts
python-socketio>=5.11
itsdangerous>=2.1

# Time-series (InfluxDB 2.x — worker vitals writes)
influxdb-client>=1.38,<2
Expand All @@ -34,4 +39,6 @@ opentelemetry-api>=1.20

# Dev / test
pytest>=7.0
pytest-asyncio>=0.23
httpx>=0.24
websockets>=11.0
58 changes: 58 additions & 0 deletions scripts/simulate_alert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import argparse
import time
from datetime import datetime, timezone

from d2.config import get_settings
from d2.streaming.kafka_io import create_producer


def _build_payload(*, critical: bool) -> dict[str, object]:
now = datetime.now(tz=timezone.utc)
payload = {
"timestamp": now.timestamp(),
"device_id": "device_1",
"pulse": 78,
"blood_oxygen": 93.0,
"temperature": 36.8,
"IsFalling": False,
"IsMoving": False,
}
if critical:
payload["blood_oxygen"] = 85.0
return payload


def main() -> int:
parser = argparse.ArgumentParser(description="Publish a single vitals payload to Kafka.")
parser.add_argument(
"--critical",
action="store_true",
help="Send a critical SpO2 reading to trigger alert logic downstream.",
)
args = parser.parse_args()

settings = get_settings()
producer = create_producer(settings)
payload = _build_payload(critical=args.critical)

try:
future = producer.send(
settings.kafka_topic_vitals_raw,
key=str(payload["device_id"]).encode("utf-8"),
value=payload,
)
future.get(timeout=10)
producer.flush(timeout=10)
print(
f"Sent payload to {settings.kafka_topic_vitals_raw} at {time.strftime('%H:%M:%S')}: {payload}"
)
return 0
except Exception as exc:
print(f"Failed to publish payload: {exc}")
return 1
finally:
producer.close()


if __name__ == "__main__":
raise SystemExit(main())
184 changes: 181 additions & 3 deletions src/d2/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@

from __future__ import annotations

import asyncio
import json
import logging
import socket
import ssl
import time
from contextlib import asynccontextmanager
from typing import Any

import socketio
from aiokafka import AIOKafkaConsumer
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
Expand All @@ -21,9 +27,17 @@
from sqlalchemy.exc import OperationalError

from d2.api.routes import health, vitals, vitals_routes, patients, alerts, devices, reports
from d2.api.schemas.alert import AlertPayload
from d2.config import database_hostname, get_settings
from d2.db.connection import dispose_engine
from d2.observability.metrics import API_ERRORS_TOTAL, API_REQUESTS_TOTAL, API_RESPONSE_SECONDS
from d2.db import repository
from d2.db.connection import dispose_engine, get_session_maker
from d2.observability.metrics import (
ALERTS_DELIVERED,
ALERTS_DELIVERY_FAILED,
API_ERRORS_TOTAL,
API_REQUESTS_TOTAL,
API_RESPONSE_SECONDS,
)

log = logging.getLogger(__name__)

Expand All @@ -34,6 +48,146 @@
)
}

def _split_origins(raw: str | None) -> list[str]:
if not raw:
return []
return [origin.strip() for origin in raw.split(",") if origin.strip()]

def format_patient_slug(patient_id: int) -> str:
return f"PAT-{patient_id:04d}"

def _severity_from_alert(alert: dict[str, Any]) -> str:
if alert.get("is_falling"):
return "critical"
for key in ("hr_status", "spo2_status", "temp_status"):
if alert.get(key) == "critical":
return "critical"
return "warning"

def _message_from_alert(alert: dict[str, Any]) -> str:
reason = alert.get("reason", "Unknown anomaly detected")
return reason.replace("_", " ").capitalize()

async def _resolve_patient_context(
session: Any, alert: dict[str, Any]
) -> tuple[int | None, str]:
patient_id = alert.get("patient_id")
patient_name = "Unknown Patient"

if patient_id:
row = await repository.fetch_patient_by_id(session, int(patient_id))
if row and row.get("full_name"):
patient_name = str(row["full_name"]).strip()

return (int(patient_id) if patient_id else None, patient_name)

def _build_alert_payload(
alert: dict[str, Any],
patient_id: int | None,
patient_name: str,
) -> AlertPayload:
patient_slug = format_patient_slug(patient_id or 0)
return AlertPayload(
patientId=patient_slug,
patientName=patient_name,
message=_message_from_alert(alert),
severity=_severity_from_alert(alert),
)

def _build_kafka_ssl_context(settings: Any) -> ssl.SSLContext | None:
if settings.kafka_security_protocol.upper() == "PLAINTEXT":
return None
context = ssl.create_default_context()
# In main, we might not have resolve_repo_relative_path, so just use raw paths or skip if not found
ca = getattr(settings, "kafka_ssl_cafile", None)
cert = getattr(settings, "kafka_ssl_certfile", None)
key = getattr(settings, "kafka_ssl_keyfile", None)
if ca:
context.load_verify_locations(cafile=ca)
if cert and key:
context.load_cert_chain(certfile=cert, keyfile=key)
return context

def _register_socket_handlers(
socket_server: socketio.AsyncServer,
*,
require_auth: bool,
) -> None:
@socket_server.event
async def connect(sid, environ, auth):
# We removed InMemorySessionStore logic here, we just accept connections
# Keycloak integration will happen in D4.
log.info("[Socket] client connected sid=%s", sid)

@socket_server.event
async def disconnect(sid):
log.info("[Socket] client disconnected sid=%s", sid)

@socket_server.on("join_room")
async def join_room(sid, payload):
room_id = None
if isinstance(payload, dict):
room_id = payload.get("roomId") or payload.get("room_id")
if room_id:
await socket_server.enter_room(sid, str(room_id))
log.info("[Socket] %s joined room=%s", sid, room_id)

@socket_server.on("leave_room")
async def leave_room(sid, payload):
room_id = None
if isinstance(payload, dict):
room_id = payload.get("roomId") or payload.get("room_id")
if room_id:
await socket_server.leave_room(sid, str(room_id))
log.info("[Socket] %s left room=%s", sid, room_id)

async def _consume_alerts(socket_server: socketio.AsyncServer, settings: Any) -> None:
ssl_context = _build_kafka_ssl_context(settings)
session_maker = get_session_maker()
retry_delay = 5

while True:
try:
consumer = AIOKafkaConsumer(
getattr(settings, "kafka_topic_alerts", "alerts"),
bootstrap_servers=settings.kafka_bootstrap_servers.split(","),
group_id="d2-socketio-broadcaster",
value_deserializer=lambda m: json.loads(m.decode("utf-8")),
enable_auto_commit=True,
security_protocol=settings.kafka_security_protocol,
ssl_context=ssl_context,
)
await consumer.start()
log.info("Socket.IO alert consumer started topic=%s", getattr(settings, "kafka_topic_alerts", "alerts"))
async for msg in consumer:
alert = msg.value
async with session_maker() as session:
patient_id, patient_name = await _resolve_patient_context(session, alert)
payload = _build_alert_payload(alert, patient_id, patient_name)
severity = payload.severity

try:
await socket_server.emit("alert", payload.model_dump())
ALERTS_DELIVERED.labels(severity=severity).inc()
log.info(
"[Socket] alert broadcast severity=%s patient_id=%s",
severity,
patient_id,
)
except Exception:
ALERTS_DELIVERY_FAILED.labels(severity=severity).inc()
log.exception("[Socket] alert broadcast failed")
except asyncio.CancelledError:
break
except Exception:
log.exception("Socket.IO alert consumer failed; retrying in %ss", retry_delay)
await asyncio.sleep(retry_delay)
finally:
try:
await consumer.stop()
except Exception:
pass


@asynccontextmanager
async def _lifespan(app: FastAPI):
Expand All @@ -44,18 +198,41 @@ async def _lifespan(app: FastAPI):
database_hostname(settings.database_url),
settings.pg_ssl_required,
)

socket_server: socketio.AsyncServer | None = getattr(app.state, "socketio_server", None)
kafka_task: asyncio.Task | None = None
if socket_server is not None:
kafka_task = asyncio.create_task(_consume_alerts(socket_server, settings))

yield

if kafka_task is not None:
kafka_task.cancel()
try:
await kafka_task
except asyncio.CancelledError:
pass

await dispose_engine()


def create_app() -> FastAPI:
settings = get_settings()
app = FastAPI(
title="D2 Intelligence API",
description="Data & Intelligence layer — vitals scoring, history, alerts integration.",
version="0.1.0",
lifespan=_lifespan,
)

socketio_origins = _split_origins(settings.socketio_allow_origins or settings.cors_allow_origins)
socket_server = socketio.AsyncServer(
async_mode="asgi",
cors_allowed_origins=socketio_origins or ["*"],
)
_register_socket_handlers(socket_server, require_auth=settings.socketio_require_auth)
app.state.socketio_server = socket_server

# CORS — allow all origins for dev / demo; restrict to D3 origin in production.
app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -88,6 +265,8 @@ async def _metrics_middleware(request: Request, call_next):
status=str(status_code),
).inc()

app.mount("/socket.io", socketio.ASGIApp(socket_server))

@app.get("/metrics", include_in_schema=False)
def metrics() -> Response:
"""Prometheus scrape endpoint."""
Expand All @@ -98,7 +277,6 @@ def _db_unreachable_response(
) -> JSONResponse:
log.warning("Database connection failed (%s): %s", label, exc)
return JSONResponse(status_code=503, content=_DB_UNREACHABLE_BODY)


@app.exception_handler(OperationalError)
async def _db_sa_error(request: Request, exc: OperationalError) -> JSONResponse:
Expand Down
10 changes: 8 additions & 2 deletions src/d2/api/routes/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@

from __future__ import annotations

import logging
from fastapi import APIRouter, HTTPException
from sqlalchemy import text

from d2.api.converters import row_to_live_alert
from d2.api.deps import DbSession
from d2.api.schemas import AcknowledgeAlertResponse, LiveAlert
from d2.db import repository
from d2.observability.metrics import ALERTS_ACKNOWLEDGED

log = logging.getLogger(__name__)
router = APIRouter()


@router.get("/alerts", response_model=list[LiveAlert], tags=["alerts"])
async def get_alerts(session: DbSession, limit: int = 500) -> list[LiveAlert]:
async def get_alerts(session: DbSession, limit: int = 500, patient_id: int | None = None) -> list[LiveAlert]:
"""All alerts, newest first (see ``alerts`` + ``patients`` join)."""
rows = await repository.list_live_alerts(session, limit=limit)
rows = await repository.list_live_alerts(session, limit=limit, patient_id=patient_id)
return [row_to_live_alert(dict(r)) for r in rows]


Expand All @@ -34,6 +37,9 @@ async def acknowledge_alert(alert_id: str, session: DbSession) -> AcknowledgeAle
raise HTTPException(status_code=404, detail=f"Alert with id '{alert_id}' not found")

await repository.acknowledge_alert(session, aid)
ALERTS_ACKNOWLEDGED.inc()
log.info(f"Alert {alert_id} acknowledged")

return AcknowledgeAlertResponse(
success=True, message=f"Alert {alert_id} acknowledged successfully"
)
5 changes: 5 additions & 0 deletions src/d2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def influx_configured(self) -> bool:
# Optional Keycloak checks (D4 will supply values)
keycloak_well_known_url: str | None = None

# CORS and Socket.IO
cors_allow_origins: str | None = None
socketio_allow_origins: str | None = None
socketio_require_auth: bool = False

@model_validator(mode="after")
def ensure_asyncpg_driver_in_database_url(self) -> Self:
"""SQLAlchemy AsyncEngine requires asyncpg. Supabase / dashboards often paste
Expand Down
5 changes: 3 additions & 2 deletions src/d2/db/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ async def list_alerts(
FROM alerts a
LEFT JOIN patients p ON p.id = a.patient_id
LEFT JOIN users u ON u.id = p.user_id
WHERE (:patient_id::int IS NULL OR a.patient_id = :patient_id)
ORDER BY a.ts DESC
LIMIT :limit
"""
Expand Down Expand Up @@ -505,8 +506,8 @@ async def fetch_vitals_for_patient_range(
return [dict(r) for r in result.mappings().all()]


async def list_live_alerts(session: AsyncSession, limit: int = 500) -> list[dict[str, Any]]:
result = await session.execute(_LIST_ALERTS_LIVE, {"limit": limit})
async def list_live_alerts(session: AsyncSession, limit: int = 500, patient_id: int | None = None) -> list[dict[str, Any]]:
result = await session.execute(_LIST_ALERTS_LIVE, {"limit": limit, "patient_id": patient_id})
return [dict(r) for r in result.mappings().all()]


Expand Down
Loading
Loading