From c37d92a5922f50eb15a9474e34c6ea89b275b24b Mon Sep 17 00:00:00 2001 From: Eniola-12 Date: Mon, 8 Jun 2026 00:02:34 +0100 Subject: [PATCH] feat: add audit logging for link generation and access --- app/adapters/persistence/postgres_storage.py | 29 +++++ app/api/dependencies/services.py | 4 + app/api/router.py | 21 ++++ app/domain/models.py | 27 +++++ app/main.py | 2 + app/ports/storage_port.py | 8 ++ app/services/__init__.py | 2 + app/services/audit_service.py | 112 +++++++++++++++++++ migrations/0006_add_audit_logs.py | 25 +++++ pyproject.toml | 1 + 10 files changed, 231 insertions(+) create mode 100644 app/services/audit_service.py create mode 100644 migrations/0006_add_audit_logs.py diff --git a/app/adapters/persistence/postgres_storage.py b/app/adapters/persistence/postgres_storage.py index 517e5a1..7bb365c 100644 --- a/app/adapters/persistence/postgres_storage.py +++ b/app/adapters/persistence/postgres_storage.py @@ -9,11 +9,14 @@ """ from __future__ import annotations +from tkinter.constants import INSERT + import asyncpg # type: ignore from datetime import datetime from typing import Iterable, Optional from app.domain import ( + AuditLog, Document, DocumentGroup, DocumentPermission, @@ -25,6 +28,8 @@ from app.infra.factories import StorageFactory from app.ports.storage_port import StoragePort +from HexShare.app.infra.factories import AuthenticatorFactory + class PostgresStorage(StoragePort): def __init__(self, pool: asyncpg.Pool) -> None: @@ -550,6 +555,30 @@ async def update_document_room( row = await con.fetchrow(sql, tenant_id, document_id, room_id) return self._row_to_document(row) if row else None + async def save_audit_log(self, log: AuditLog) -> None: + sql = """ + INSERT INTO audit_logs ( + id, tenant_id, event_type, link_id, document_id, + actor, ip_address, device, location, timestamp + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, $8, $9, $10 + ) + """ + async with self._pool.acquire() as con: + await con.execute( + sql, + log.id, + log.tenant_id, + log.event_type, + log.link_id, + log.document_id, + log.actor, + log.ip_address, + log.device, + log.location, + log.timestamp, + ) @StorageFactory.register("postgres") def create_postgres_storage(*, pool, **_) -> StoragePort: diff --git a/app/api/dependencies/services.py b/app/api/dependencies/services.py index b15c57b..c58885e 100644 --- a/app/api/dependencies/services.py +++ b/app/api/dependencies/services.py @@ -47,3 +47,7 @@ def get_oidc_client_service(request: Request): def get_iam_policy(request: Request): return request.app.state.iam_policy + + +def get_audit_service(request: Request): + return request.app.state.audit_service \ No newline at end of file diff --git a/app/api/router.py b/app/api/router.py index a556da4..122dab7 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -10,6 +10,7 @@ from app.api.dependencies.services import ( get_analytics_service, + get_audit_service, get_document_group_service, get_document_service, get_iam_policy, @@ -32,6 +33,7 @@ ) from app.services import ( AnalyticsService, + AuditService, DocumentGroupService, DocumentService, DocumentProcessingError, @@ -225,6 +227,7 @@ async def list_document_links( @router.post("/documents/{document_id}/links", response_model=ShareLinkResponse) async def create_link( document_id: str, + request: Request, expires_in: int = Query(3600, description="Seconds until link expiry"), can_download: bool = Query(False), can_print: bool = Query(False), @@ -233,6 +236,7 @@ async def create_link( principal: TenantPrincipal = Depends(get_tenant_auth()), document_service: DocumentService = Depends(get_document_service), link_service: LinkService = Depends(get_link_service), + audit_service: AuditService = Depends(get_audit_service), ) -> ShareLinkResponse: try: await document_service.require_document_access( @@ -254,6 +258,14 @@ async def create_link( require_email=require_email, allowed_emails=allowed_emails, ) + await audit_service.log_link_created( + tenant_id= principal.tenant_id, + link_id= link.id, + document_id=document_id, + actor=principal.user_id, + ip_address=request.client.host if request.client else None, + user_agent=request.headers.get("user-agent"), + ) token = await link_service.generate_share_token(link) return _serialize_link(link, token) @@ -545,6 +557,7 @@ async def create_view_session( request: Request, share_auth: ShareTokenDependency = Depends(get_share_auth), viewer_service: ViewerService = Depends(get_viewer_service), + audit_service: AuditService = Depends(get_audit_service), ) -> CreateViewSessionResponse: claims: ShareTokenClaims = share_auth(token) try: @@ -556,6 +569,14 @@ async def create_view_session( ip_address=request.client.host if request.client else None, user_agent=request.headers.get("user-agent"), ) + await audit_service.log_link_accessed( + tenant_id=claims.tenant_id, + link_id=claims.link_id, + document_id=claims.document_id, + actor=payload.email, + ip_address=request.client.host if request.client else None, + user_agent=request.headers.get("user-agent"), + ) delivery = await viewer_service.describe_view_session_delivery( tenant_id=claims.tenant_id, session_id=session.id, diff --git a/app/domain/models.py b/app/domain/models.py index 8776646..61a7075 100644 --- a/app/domain/models.py +++ b/app/domain/models.py @@ -213,3 +213,30 @@ def validate_page_number(cls, v, values): # type: ignore[override] if event_type == EventType.PAGE_VIEW and v is None: raise ValueError("page_number is required for page_view events") return v + + +class AuditLog(BaseModel): + """Records who did what with a share link and from where. + + ip_address: + Raw IP address of the request. + device: + Human-readable device name parsed from User-Agent + e.g. 'iPhone 13 Pro / Mobile Safari'. + location: + City and country derived from IP address + e.g. 'Lagos, Nigeria'. + timestamp: + When the event occurred. + """ + + id: str + tenant_id: str + event_type: str + link_id: str + document_id: str + actor: str + ip_address: Optional[str] = None + device: Optional[str] = None + location: Optional[str] = None + timestamp: datetime diff --git a/app/main.py b/app/main.py index 39a5a16..a613cfe 100644 --- a/app/main.py +++ b/app/main.py @@ -30,6 +30,7 @@ ) from app.services import ( AnalyticsService, + AuditService, DocumentProcessor, DocumentGroupService, DocumentService, @@ -136,6 +137,7 @@ async def lifespan(fastapi_app: FastAPI): fastapi_app.state.link_service = link_service fastapi_app.state.viewer_service = viewer_service fastapi_app.state.analytics_service = AnalyticsService(persistence_layer) + fastapi_app.state.audit_service = AuditService(persistence_layer) fastapi_app.state.access_control = access_control fastapi_app.state.tenant_auth = TenantAuthDependency(authenticator=authenticator) fastapi_app.state.share_auth = ShareTokenDependency(token_port=token_adapter) diff --git a/app/ports/storage_port.py b/app/ports/storage_port.py index a8440e1..ea99c99 100644 --- a/app/ports/storage_port.py +++ b/app/ports/storage_port.py @@ -15,6 +15,7 @@ from datetime import datetime from app.domain import ( + AuditLog, Document, DocumentGroup, DocumentPermission, @@ -23,6 +24,8 @@ ViewEvent, ) +from HexShare.app.domain.models import AuditLog + class StoragePort(ABC): """Abstract base class for document and link persistence.""" @@ -190,3 +193,8 @@ async def update_document_room( self, *, tenant_id: str, document_id: str, room_id: Optional[str] ) -> Optional[Document]: """Move a document to a group (room_id) or remove from group (room_id=None).""" + + + @abstractmethod + async def save_audit_log(self, log: AuditLog) -> None: + """Persist an audit log entry""" \ No newline at end of file diff --git a/app/services/__init__.py b/app/services/__init__.py index 8e85be2..7b796f7 100644 --- a/app/services/__init__.py +++ b/app/services/__init__.py @@ -3,6 +3,7 @@ """ from .analytics_service import AnalyticsService +from .audit_service import AuditService from .document_processor import ( DocumentProcessor, DocumentProcessingError, @@ -20,6 +21,7 @@ __all__ = [ "AnalyticsService", + "AuditService", "DocumentProcessor", "DocumentProcessingError", "DocumentGroupService", diff --git a/app/services/audit_service.py b/app/services/audit_service.py new file mode 100644 index 0000000..78cc477 --- /dev/null +++ b/app/services/audit_service.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime, timezone + +import httpx +from user_agents import parse as parse_ua + +from app.domain import AuditLog +from app.ports.storage_port import StoragePort + + +class AuditService: + """Records audit events for link creation and access.""" + + def __init__(self, storage: StoragePort) -> None: + self._storage = storage + + @staticmethod + def _now() -> datetime: + return datetime.now(timezone.utc).replace(tzinfo=None) + + @staticmethod + def _parse_device(user_agent: str | None) -> str: + """Parse User-Agent string into a human-readable device name.""" + if not user_agent: + return "Unknown Device" + ua = parse_ua(user_agent) + device = ua.device.family + browser = ua.browser.family + os = ua.os.family + if device and device.lower() != "other": + return f"{device} / {browser}" + return f"{os} / {browser}" + + @staticmethod + async def _lookup_location(ip_address: str | None) -> str: + """Derive city and country from IP address using ip-api.com.""" + if not ip_address or ip_address in ("127.0.0.1", "localhost"): + return "Unknown Location" + try: + async with httpx.AsyncClient(timeout=3.0) as client: + response = await client.get( + f"http://ip-api.com/json/{ip_address}", + params={"fields": "status,city,country"}, + ) + data = response.json() + if data.get("status") == "success": + city = data.get("city", "") + country = data.get("country", "") + return f"{city}, {country}".strip(", ") + except Exception: + pass + return "Unknown Location" + + async def log_link_created( + self, + *, + tenant_id: str, + link_id: str, + document_id: str, + actor: str, + ip_address: str | None, + user_agent: str | None, + ) -> None: + """Log a link.created audit event.""" + device, location = await asyncio.gather( + asyncio.to_thread(self._parse_device, user_agent), + self._lookup_location(ip_address), + ) + log = AuditLog( + id=self._storage.generate_id("aud"), + tenant_id=tenant_id, + event_type="link.created", + link_id=link_id, + document_id=document_id, + actor=actor, + ip_address=ip_address, + device=device, + location=location, + timestamp=self._now(), + ) + await self._storage.save_audit_log(log) + + async def log_link_accessed( + self, + *, + tenant_id: str, + link_id: str, + document_id: str, + actor: str | None, + ip_address: str | None, + user_agent: str | None, + ) -> None: + """Log a link.accessed audit event.""" + device, location = await asyncio.gather( + asyncio.to_thread(self._parse_device, user_agent), + self._lookup_location(ip_address), + ) + log = AuditLog( + id=self._storage.generate_id("aud"), + tenant_id=tenant_id, + event_type="link.accessed", + link_id=link_id, + document_id=document_id, + actor=actor or "anonymous", + ip_address=ip_address, + device=device, + location=location, + timestamp=self._now(), + ) + await self._storage.save_audit_log(log) \ No newline at end of file diff --git a/migrations/0006_add_audit_logs.py b/migrations/0006_add_audit_logs.py new file mode 100644 index 0000000..24a23cd --- /dev/null +++ b/migrations/0006_add_audit_logs.py @@ -0,0 +1,25 @@ +"""Add audit_logs table for link creation and access tracking.""" + +from yoyo import step + +steps = [ + step( + """ + CREATE TABLE audit_logs ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL, + event_type TEXT NOT NULL CHECK ( + event_type IN ('link.created', 'link.accessed') + ), + link_id TEXT NOT NULL REFERENCES share_links(id) ON DELETE CASCADE, + document_id TEXT NOT NULL REFERENCES documents(id) ON DELETE CASCADE, + actor TEXT NOT NULL, + ip_address TEXT NULL, + device TEXT NULL, + location TEXT NULL, + timestamp TIMESTAMP WITHOUT TIME ZONE NOT NULL + ) + """, + "DROP TABLE audit_logs" + ), +] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 479d7d0..5158bdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ pyjwt = "^2.11.0" asyncpg = "^0.31.0" python-multipart = "^0.0.22" httpx = {extras = ["standard"], version = "^0.28.1"} +user-agents = "^2.2.0" uvicorn = {extras = ["standard"], version = "^0.41.0"} yoyo-migrations = "^9.0.0" boto3 = "^1.42.73"