Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ dev = [
"pytest-asyncio>=0.23",
"httpx>=0.24",
"ruff>=0.1.0",
"aiokafka>=0.10",
"python-socketio>=5.11",
"starlette>=0.27",
]

[project.scripts]
Expand Down
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ vitaldb>=1.4

# Streaming (member: Kafka track)
kafka-python>=2.0
aiokafka>=0.10

# Real-time alerts (Socket.IO)
python-socketio>=5.11

# Session middleware (SessionMiddleware requires itsdangerous)
itsdangerous>=2.1

# Time-series (InfluxDB 2.x — worker vitals writes)
influxdb-client>=1.38,<2
Expand Down
58 changes: 58 additions & 0 deletions scripts/simulate_alert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import argparse
import time
from datetime import datetime, timezone

from d2.config import get_settings
from d2.streaming.kafka_io import create_producer


def _build_payload(*, critical: bool) -> dict[str, object]:
now = datetime.now(tz=timezone.utc)
payload = {
"timestamp": now.timestamp(),
"device_id": "device_1",
"pulse": 78,
"blood_oxygen": 93.0,
"temperature": 36.8,
"IsFalling": False,
"IsMoving": False,
}
if critical:
payload["blood_oxygen"] = 85.0
return payload


def main() -> int:
parser = argparse.ArgumentParser(description="Publish a single vitals payload to Kafka.")
parser.add_argument(
"--critical",
action="store_true",
help="Send a critical SpO2 reading to trigger alert logic downstream.",
)
args = parser.parse_args()

settings = get_settings()
producer = create_producer(settings)
payload = _build_payload(critical=args.critical)

try:
future = producer.send(
settings.kafka_topic_vitals_raw,
key=str(payload["device_id"]).encode("utf-8"),
value=payload,
)
future.get(timeout=10)
producer.flush(timeout=10)
print(
f"Sent payload to {settings.kafka_topic_vitals_raw} at {time.strftime('%H:%M:%S')}: {payload}"
)
return 0
except Exception as exc:
print(f"Failed to publish payload: {exc}")
return 1
finally:
producer.close()


if __name__ == "__main__":
raise SystemExit(main())
227 changes: 220 additions & 7 deletions src/d2/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,33 @@

from __future__ import annotations

import logging
import socket
import time
import ssl
from contextlib import asynccontextmanager
from typing import Any

from fastapi import FastAPI
import socketio
from aiokafka import AIOKafkaConsumer
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse, Response
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
from sqlalchemy.exc import OperationalError
from starlette.middleware.sessions import SessionMiddleware

from d2.api.routes import health, vitals, vitals_routes, patients, alerts, devices, reports
from d2.config import database_hostname, get_settings
from d2.db.connection import dispose_engine
from d2.observability.metrics import API_ERRORS_TOTAL, API_REQUESTS_TOTAL, API_RESPONSE_SECONDS
from d2.api.deps import get_current_user
from d2.api.id_utils import format_patient_slug
from d2.api.routes import auth, alerts, devices, health, patients, reports, vitals, vitals_routes
from d2.api.schemas.alert import AlertPayload
from d2.api.session import InMemorySessionStore
from d2.config import database_hostname, get_settings, resolve_repo_relative_path
from d2.db import repository
from d2.db.connection import dispose_engine, get_session_maker
from d2.observability.metrics import (
ALERTS_DELIVERED, ALERTS_DELIVERY_FAILED,
API_ERRORS_TOTAL, API_REQUESTS_TOTAL, API_RESPONSE_SECONDS,
)

log = logging.getLogger(__name__)

Expand All @@ -35,6 +46,174 @@
}


def _split_origins(raw: str | None) -> list[str]:
if not raw:
return []
return [origin.strip() for origin in raw.split(",") if origin.strip()]


def _extract_auth_token(auth: Any) -> str | None:
if not isinstance(auth, dict):
return None
token = auth.get("token")
if not token:
return None
token_str = str(token).strip()
if token_str.lower().startswith("bearer "):
return token_str[7:].strip()
return token_str


def _severity_from_alert(alert: dict[str, Any]) -> str:
if alert.get("is_falling"):
return "critical"
for key in ("hr_status", "spo2_status", "temp_status"):
if str(alert.get(key) or "").lower() == "critical":
return "critical"
for key in ("hr_status", "spo2_status", "temp_status"):
if str(alert.get(key) or "").lower() == "warning":
return "warning"
return "info"


def _message_from_alert(alert: dict[str, Any]) -> str:
reason = str(alert.get("reason") or "").strip()
return reason if reason else "threshold_breach"


async def _resolve_patient_context(
session: Any,
alert: dict[str, Any],
) -> tuple[int | None, str]:
patient_id = alert.get("patient_id")
if not patient_id and alert.get("device_id"):
patient_id = await repository.fetch_assigned_patient_for_device(
session, str(alert["device_id"])
)

patient_name = "Unknown patient"
if patient_id:
row = await repository.fetch_patient_by_id(session, int(patient_id))
if row and row.get("full_name"):
patient_name = str(row["full_name"]).strip()

return (int(patient_id) if patient_id else None, patient_name)


def _build_alert_payload(
alert: dict[str, Any],
patient_id: int | None,
patient_name: str,
) -> AlertPayload:
patient_slug = format_patient_slug(patient_id or 0)
return AlertPayload(
patientId=patient_slug,
patientName=patient_name,
message=_message_from_alert(alert),
severity=_severity_from_alert(alert),
)


def _build_kafka_ssl_context(settings: Any) -> ssl.SSLContext | None:
if settings.kafka_security_protocol.upper() == "PLAINTEXT":
return None
context = ssl.create_default_context()
ca = resolve_repo_relative_path(settings.kafka_ssl_cafile)
cert = resolve_repo_relative_path(settings.kafka_ssl_certfile)
key = resolve_repo_relative_path(settings.kafka_ssl_keyfile)
if ca:
context.load_verify_locations(cafile=ca)
if cert and key:
context.load_cert_chain(certfile=cert, keyfile=key)
return context


def _register_socket_handlers(
socket_server: socketio.AsyncServer,
*,
require_auth: bool,
) -> None:
@socket_server.event
async def connect(sid, environ, auth):
token = _extract_auth_token(auth)
if require_auth:
if not token:
raise socketio.exceptions.ConnectionRefusedError("auth required")
if not InMemorySessionStore.get_session(token):
raise socketio.exceptions.ConnectionRefusedError("invalid session")
log.info("[Socket] client connected sid=%s", sid)

@socket_server.event
async def disconnect(sid):
log.info("[Socket] client disconnected sid=%s", sid)

@socket_server.on("join_room")
async def join_room(sid, payload):
room_id = None
if isinstance(payload, dict):
room_id = payload.get("roomId") or payload.get("room_id")
if room_id:
await socket_server.enter_room(sid, str(room_id))
log.info("[Socket] %s joined room=%s", sid, room_id)

@socket_server.on("leave_room")
async def leave_room(sid, payload):
room_id = None
if isinstance(payload, dict):
room_id = payload.get("roomId") or payload.get("room_id")
if room_id:
await socket_server.leave_room(sid, str(room_id))
log.info("[Socket] %s left room=%s", sid, room_id)


async def _consume_alerts(socket_server: socketio.AsyncServer, settings: Any) -> None:
ssl_context = _build_kafka_ssl_context(settings)
session_maker = get_session_maker(settings)
retry_delay = 5

while True:
consumer = AIOKafkaConsumer(
settings.kafka_topic_alerts,
bootstrap_servers=settings.kafka_bootstrap_servers.split(","),
group_id="d2-socketio-broadcaster",
value_deserializer=lambda m: json.loads(m.decode("utf-8")),
enable_auto_commit=True,
security_protocol=settings.kafka_security_protocol,
ssl_context=ssl_context,
)
try:
await consumer.start()
log.info("Socket.IO alert consumer started topic=%s", settings.kafka_topic_alerts)
async for msg in consumer:
alert = msg.value
async with session_maker() as session:
patient_id, patient_name = await _resolve_patient_context(session, alert)
payload = _build_alert_payload(alert, patient_id, patient_name)
severity = payload.severity

try:
await socket_server.emit("alert", payload.model_dump())
ALERTS_DELIVERED.labels(severity=severity).inc()
log.info(
"[Socket] alert broadcast severity=%s patient_id=%s",
severity,
patient_id,
)
except Exception:
ALERTS_DELIVERY_FAILED.labels(severity=severity).inc()
log.exception("[Socket] alert broadcast failed")
except asyncio.CancelledError:
break
except Exception:
log.exception("Socket.IO alert consumer failed; retrying in %ss", retry_delay)
await asyncio.sleep(retry_delay)
finally:
try:
await consumer.stop()
except Exception:
log.exception("Socket.IO alert consumer stop failed")


@asynccontextmanager
async def _lifespan(app: FastAPI):
"""Log which DB the process uses; release engine on shutdown."""
Expand All @@ -44,18 +223,49 @@ async def _lifespan(app: FastAPI):
database_hostname(settings.database_url),
settings.pg_ssl_required,
)

socket_server: socketio.AsyncServer | None = getattr(app.state, "socketio_server", None)
kafka_task: asyncio.Task | None = None
if socket_server is not None:
kafka_task = asyncio.create_task(_consume_alerts(socket_server, settings))

yield

if kafka_task is not None:
kafka_task.cancel()
try:
await kafka_task
except asyncio.CancelledError:
pass
await dispose_engine()


def create_app() -> FastAPI:
settings = get_settings()

app = FastAPI(
title="D2 Intelligence API",
description="Data & Intelligence layer — vitals scoring, history, alerts integration.",
version="0.1.0",
lifespan=_lifespan,
)

socketio_origins = _split_origins(settings.socketio_allow_origins or settings.cors_allow_origins)
socket_server = socketio.AsyncServer(
async_mode="asgi",
cors_allowed_origins=socketio_origins or ["*"],
)
_register_socket_handlers(socket_server, require_auth=settings.socketio_require_auth)
app.state.socketio_server = socket_server

app.add_middleware(
SessionMiddleware,
secret_key=settings.session_secret_key,
https_only=False,
same_site="lax",
max_age=600,
)

# CORS — allow all origins for dev / demo; restrict to D3 origin in production.
app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -88,6 +298,9 @@ async def _metrics_middleware(request: Request, call_next):
status=str(status_code),
).inc()

# Mount Socket.IO server on /socket.io (matches frontend default path).
app.mount("/socket.io", socketio.ASGIApp(socket_server))

@app.get("/metrics", include_in_schema=False)
def metrics() -> Response:
"""Prometheus scrape endpoint."""
Expand Down
Loading