diff --git a/fastpubsub/api/app.py b/fastpubsub/api/app.py index 4770983..2da6e5a 100644 --- a/fastpubsub/api/app.py +++ b/fastpubsub/api/app.py @@ -1,3 +1,5 @@ +"""FastAPI application setup and configuration.""" + from fastapi import FastAPI, Request, status from fastapi.responses import ORJSONResponse from prometheus_fastapi_instrumentator import Instrumentator @@ -36,6 +38,14 @@ def create_app() -> FastAPI: + """Create and configure the FastAPI application. + + Sets up the complete application including middleware, exception handlers, + routers, and monitoring instrumentation. + + Returns: + Configured FastAPI application instance. + """ app = FastAPI( title="fastpubsub", description="Simple pubsub system based on FastAPI and PostgreSQL.", @@ -49,22 +59,77 @@ def create_app() -> FastAPI: # Add exception handlers @app.exception_handler(AlreadyExistsError) def already_exists_exception_handler(request: Request, exc: AlreadyExistsError): + """Handle AlreadyExistsError exceptions. + + Returns a 409 Conflict response when attempting to create resources that already exist. + + Args: + request: The incoming HTTP request. + exc: The AlreadyExistsError exception. + + Returns: + JSON error response with 409 status code. + """ return _create_error_response(models.GenericError, status.HTTP_409_CONFLICT, exc) @app.exception_handler(NotFoundError) def not_found_exception_handler(request: Request, exc: NotFoundError): + """Handle NotFoundError exceptions. + + Returns a 404 Not Found response when requesting non-existent resources. + + Args: + request: The incoming HTTP request. + exc: The NotFoundError exception. + + Returns: + JSON error response with 404 status code. + """ return _create_error_response(models.GenericError, status.HTTP_404_NOT_FOUND, exc) @app.exception_handler(ServiceUnavailable) def service_unavailable_exception_handler(request: Request, exc: ServiceUnavailable): + """Handle ServiceUnavailable exceptions. + + Returns a 503 Service Unavailable response when services are unavailable. + + Args: + request: The incoming HTTP request. + exc: The ServiceUnavailable exception. + + Returns: + JSON error response with 503 status code. + """ return _create_error_response(models.GenericError, status.HTTP_503_SERVICE_UNAVAILABLE, exc) @app.exception_handler(InvalidClient) def invalid_client_exception_handler(request: Request, exc: InvalidClient): + """Handle InvalidClient exceptions. + + Returns a 401 Unauthorized response when client authentication fails. + + Args: + request: The incoming HTTP request. + exc: The InvalidClient exception. + + Returns: + JSON error response with 401 status code. + """ return _create_error_response(models.GenericError, status.HTTP_401_UNAUTHORIZED, exc) @app.exception_handler(InvalidClientToken) def invalid_client_token_exception_handler(request: Request, exc: InvalidClientToken): + """Handle InvalidClientToken exceptions. + + Returns a 403 Forbidden response when client token is invalid or expired. + + Args: + request: The incoming HTTP request. + exc: The InvalidClientToken exception. + + Returns: + JSON error response with 403 status code. + """ return _create_error_response(models.GenericError, status.HTTP_403_FORBIDDEN, exc) # Add routers diff --git a/fastpubsub/api/helpers.py b/fastpubsub/api/helpers.py index 9549379..b232051 100644 --- a/fastpubsub/api/helpers.py +++ b/fastpubsub/api/helpers.py @@ -1,8 +1,19 @@ +"""Helper functions for API responses and error handling.""" + from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse def _create_error_response(model_class, status_code: int, exc: Exception): - """Helper to create error responses.""" + """Create a standardized error response JSON object. + + Args: + model_class: Pydantic model class for the error response. + status_code: HTTP status code for the error. + exc: Exception instance containing the error message. + + Returns: + JSONResponse with formatted error content and appropriate status code. + """ response = jsonable_encoder(model_class(detail=exc.args[0])) return JSONResponse(status_code=status_code, content=response) diff --git a/fastpubsub/api/middlewares.py b/fastpubsub/api/middlewares.py index 1f1b43a..427fa73 100644 --- a/fastpubsub/api/middlewares.py +++ b/fastpubsub/api/middlewares.py @@ -1,3 +1,5 @@ +"""HTTP middleware for request logging and monitoring.""" + import time from uuid import uuid7 @@ -9,6 +11,26 @@ async def log_requests(request: Request, call_next): + """Middleware to log HTTP requests and responses with timing and request IDs. + + This middleware: + - Generates a unique request ID for tracking + - Logs request details at the start + - Measures processing time + - Logs response details including status code and timing + - Adds request ID header to response + - Handles and logs any exceptions during request processing + + Args: + request: The incoming FastAPI request. + call_next: The next middleware or endpoint to call. + + Returns: + The processed HTTP response. + + Raises: + Exception: Re-raises any exceptions encountered during processing. + """ start_time = time.perf_counter() request_id = str(uuid7()) logger.info( diff --git a/fastpubsub/api/routers/clients.py b/fastpubsub/api/routers/clients.py index 21303b0..5db26cf 100644 --- a/fastpubsub/api/routers/clients.py +++ b/fastpubsub/api/routers/clients.py @@ -1,3 +1,5 @@ +"""API endpoints for client management operations.""" + import uuid from typing import Annotated @@ -18,6 +20,22 @@ async def create_client( data: models.CreateClient, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("clients", "create"))], ): + """Create a new client with specified name and scopes. + + Creates a new authorized client that can access the pub/sub API + based on their granted scopes. Returns the client ID and generated secret. + + Args: + data: Client creation data including name, scopes, and active status. + token: Decoded client token with 'clients:create' scope. + + Returns: + CreateClientResult containing the new client ID and secret. + + Raises: + AlreadyExistsError: If a client with the same ID already exists. + InvalidClient: If the requesting client lacks 'clients:create' scope. + """ return await services.create_client(data) @@ -32,6 +50,22 @@ async def get_client( id: uuid.UUID, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("clients", "read"))], ): + """Retrieve a client by ID. + + Returns the full details of an existing client including ID, name, + scopes, status, and timestamps. + + Args: + id: UUID of the client to retrieve. + token: Decoded client token with 'clients:read' scope. + + Returns: + Client model with full client details. + + Raises: + NotFoundError: If no client with the given ID exists. + InvalidClient: If the requesting client lacks 'clients:read' scope. + """ return await services.get_client(id) @@ -47,6 +81,23 @@ async def update_client( data: models.UpdateClient, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("clients", "update"))], ): + """Update an existing client's name, scopes, or active status. + + Modifies the properties of an existing client. Only the fields + provided in the update data will be modified. + + Args: + id: UUID of the client to update. + data: Updated client data including name, scopes, and/or active status. + token: Decoded client token with 'clients:update' scope. + + Returns: + Client model with updated details. + + Raises: + NotFoundError: If no client with the given ID exists. + InvalidClient: If the requesting client lacks 'clients:update' scope. + """ return await services.update_client(id, data) @@ -61,6 +112,21 @@ async def list_client( offset: int = Query(default=0, ge=0), limit: int = Query(default=10, ge=1, le=100), ): + """List clients with pagination support. + + Returns a paginated list of all clients in the system. + + Args: + token: Decoded client token with 'clients:read' scope. + offset: Number of items to skip (for pagination). + limit: Maximum number of items to return (1-100). + + Returns: + ListClientAPI containing the list of clients. + + Raises: + InvalidClient: If the requesting client lacks 'clients:read' scope. + """ clients = await services.list_client(offset, limit) return models.ListClientAPI(data=clients) @@ -75,6 +141,18 @@ async def delete_client( id: uuid.UUID, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("clients", "delete"))], ): + """Delete a client by ID. + + Permanently removes a client from the system. This action cannot be undone. + + Args: + id: UUID of the client to delete. + token: Decoded client token with 'clients:delete' scope. + + Raises: + NotFoundError: If no client with the given ID exists. + InvalidClient: If the requesting client lacks 'clients:delete' scope. + """ await services.delete_client(id) @@ -85,4 +163,20 @@ async def delete_client( summary="Issue a new client token", ) async def issue_client_token(data: models.IssueClientToken): + """Issue a new JWT access token for a client. + + Generates a new access token that the client can use for authentication + in subsequent API requests. The token includes the client's scopes + and has an expiration time. + + Args: + data: Client credentials including ID and secret for authentication. + + Returns: + ClientToken containing the access token, type, expiration, and scopes. + + Raises: + InvalidClient: If client ID or secret is invalid. + ServiceUnavailable: If token generation service is unavailable. + """ return await services.issue_jwt_client_token(client_id=data.client_id, client_secret=data.client_secret) diff --git a/fastpubsub/api/routers/monitoring.py b/fastpubsub/api/routers/monitoring.py index 0828cf2..26c8d73 100644 --- a/fastpubsub/api/routers/monitoring.py +++ b/fastpubsub/api/routers/monitoring.py @@ -1,3 +1,5 @@ +"""API endpoints for monitoring and health check operations.""" + from fastapi import APIRouter, status from fastpubsub import models, services @@ -13,6 +15,15 @@ summary="Liveness probe", ) async def liveness_probe(): + """Check if the application is alive. + + Simple liveness check that always returns "alive" status. + Used by Kubernetes and other orchestration systems to determine + if the application process is running. + + Returns: + HealthCheck model with status "alive". + """ return models.HealthCheck(status="alive") @@ -24,6 +35,18 @@ async def liveness_probe(): summary="Readiness probe", ) async def readiness_probe(): + """Check if the application is ready to serve traffic. + + Comprehensive health check that verifies database connectivity. + Returns "ready" status only if all critical dependencies are available. + Used by Kubernetes to determine if the application can handle requests. + + Returns: + HealthCheck model with status "ready". + + Raises: + ServiceUnavailable: If database connection fails. + """ try: is_db_ok = await services.database_ping() if not is_db_ok: diff --git a/fastpubsub/api/routers/subscriptions.py b/fastpubsub/api/routers/subscriptions.py index 708c43d..c787202 100644 --- a/fastpubsub/api/routers/subscriptions.py +++ b/fastpubsub/api/routers/subscriptions.py @@ -1,3 +1,5 @@ +"""API endpoints for subscription management and message operations.""" + from typing import Annotated from uuid import UUID @@ -19,6 +21,23 @@ async def create_subscription( data: models.CreateSubscription, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("subscriptions", "create"))], ): + """Create a new subscription to a topic. + + Creates a subscription that defines how messages from a topic should be consumed, + including filtering, delivery attempts, and backoff configuration. + + Args: + data: Subscription creation data including ID, topic ID, filter, and delivery settings. + token: Decoded client token with 'subscriptions:create' scope. + + Returns: + Subscription model with the created subscription details. + + Raises: + AlreadyExistsError: If a subscription with the same ID already exists. + NotFoundError: If the specified topic doesn't exist. + InvalidClient: If the requesting client lacks 'subscriptions:create' scope. + """ return await services.create_subscription(data) @@ -33,6 +52,22 @@ async def get_subscription( id: str, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("subscriptions", "read"))], ): + """Retrieve a subscription by ID. + + Returns the full details of an existing subscription including ID, topic ID, + filter configuration, delivery attempts, and backoff settings. + + Args: + id: String ID of the subscription to retrieve. + token: Decoded client token with 'subscriptions:read' scope. + + Returns: + Subscription model with full subscription details. + + Raises: + NotFoundError: If no subscription with the given ID exists. + InvalidClient: If the requesting client lacks 'subscriptions:read' scope. + """ return await services.get_subscription(id) @@ -47,6 +82,21 @@ async def list_subscription( offset: int = Query(default=0, ge=0), limit: int = Query(default=10, ge=1, le=100), ): + """List subscriptions with pagination support. + + Returns a paginated list of all subscriptions in the system. + + Args: + token: Decoded client token with 'subscriptions:read' scope. + offset: Number of items to skip (for pagination). + limit: Maximum number of items to return (1-100). + + Returns: + ListSubscriptionAPI containing the list of subscriptions. + + Raises: + InvalidClient: If the requesting client lacks 'subscriptions:read' scope. + """ subscriptions = await services.list_subscription(offset, limit) return models.ListSubscriptionAPI(data=subscriptions) @@ -61,6 +111,19 @@ async def delete_subscription( id: str, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("subscriptions", "delete"))], ): + """Delete a subscription by ID. + + Permanently removes a subscription from the system. This action cannot be undone + and will also remove all messages associated with the subscription. + + Args: + id: String ID of the subscription to delete. + token: Decoded client token with 'subscriptions:delete' scope. + + Raises: + NotFoundError: If no subscription with the given ID exists. + InvalidClient: If the requesting client lacks 'subscriptions:delete' scope. + """ await services.delete_subscription(id) @@ -77,6 +140,24 @@ async def consume_messages( token: Annotated[models.DecodedClientToken, Depends(services.require_scope("subscriptions", "consume"))], batch_size: int = Query(default=10, ge=1, le=100), ): + """Consume messages from a subscription. + + Retrieves messages from the subscription queue that are available for processing. + Messages are locked to the consumer to prevent duplicate processing. + + Args: + id: String ID of the subscription to consume from. + consumer_id: Unique identifier for the consumer instance. + token: Decoded client token with 'subscriptions:consume' scope. + batch_size: Number of messages to retrieve (1-100). + + Returns: + ListMessageAPI containing the available messages. + + Raises: + NotFoundError: If no subscription with the given ID exists. + InvalidClient: If the requesting client lacks 'subscriptions:consume' scope. + """ subscription = await get_subscription(id, token) messages = await services.consume_messages( subscription_id=subscription.id, consumer_id=consumer_id, batch_size=batch_size @@ -95,6 +176,20 @@ async def ack_messages( data: list[UUID], token: Annotated[models.DecodedClientToken, Depends(services.require_scope("subscriptions", "consume"))], ): + """Acknowledge successful processing of messages. + + Marks messages as successfully processed, removing them from the queue. + Acknowledged messages will not be delivered again. + + Args: + id: String ID of the subscription. + data: List of message UUIDs to acknowledge. + token: Decoded client token with 'subscriptions:consume' scope. + + Raises: + NotFoundError: If no subscription with the given ID exists. + InvalidClient: If the requesting client lacks 'subscriptions:consume' scope. + """ subscription = await get_subscription(id, token) await services.ack_messages(subscription_id=subscription.id, message_ids=data) @@ -110,6 +205,20 @@ async def nack_messages( data: list[UUID], token: Annotated[models.DecodedClientToken, Depends(services.require_scope("subscriptions", "consume"))], ): + """Negative acknowledgment of message processing failure. + + Marks messages as failed, making them available for redelivery. + The message will be redelivered according to the subscription's backoff configuration. + + Args: + id: String ID of the subscription. + data: List of message UUIDs to negatively acknowledge. + token: Decoded client token with 'subscriptions:consume' scope. + + Raises: + NotFoundError: If no subscription with the given ID exists. + InvalidClient: If the requesting client lacks 'subscriptions:consume' scope. + """ subscription = await get_subscription(id, token) await services.nack_messages(subscription_id=subscription.id, message_ids=data) @@ -127,6 +236,24 @@ async def list_dlq( offset: int = Query(default=0, ge=0), limit: int = Query(default=10, ge=1, le=100), ): + """List messages in the dead letter queue. + + Retrieves messages that have failed delivery after exceeding the maximum + number of delivery attempts and have been moved to the DLQ. + + Args: + id: String ID of the subscription. + token: Decoded client token with 'subscriptions:consume' scope. + offset: Number of items to skip (for pagination). + limit: Maximum number of items to return (1-100). + + Returns: + ListMessageAPI containing the DLQ messages. + + Raises: + NotFoundError: If no subscription with the given ID exists. + InvalidClient: If the requesting client lacks 'subscriptions:consume' scope. + """ subscription = await get_subscription(id, token) messages = await services.list_dlq_messages(subscription_id=subscription.id, offset=offset, limit=limit) return models.ListMessageAPI(data=messages) @@ -143,6 +270,20 @@ async def reprocess_dlq( data: list[UUID], token: Annotated[models.DecodedClientToken, Depends(services.require_scope("subscriptions", "consume"))], ): + """Move dead letter queue messages back to active processing. + + Reprocesses messages from the DLQ by moving them back to the main queue + for another attempt at delivery. + + Args: + id: String ID of the subscription. + data: List of message UUIDs to reprocess. + token: Decoded client token with 'subscriptions:consume' scope. + + Raises: + NotFoundError: If no subscription with the given ID exists. + InvalidClient: If the requesting client lacks 'subscriptions:consume' scope. + """ subscription = await get_subscription(id, token) await services.reprocess_dlq_messages(subscription_id=subscription.id, message_ids=data) @@ -158,5 +299,20 @@ async def subscription_metrics( id: str, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("subscriptions", "read"))], ): + """Get metrics and statistics for a subscription. + + Returns counts of messages in different states for monitoring and analysis. + + Args: + id: String ID of the subscription. + token: Decoded client token with 'subscriptions:read' scope. + + Returns: + SubscriptionMetrics containing message counts by state. + + Raises: + NotFoundError: If no subscription with the given ID exists. + InvalidClient: If the requesting client lacks 'subscriptions:read' scope. + """ subscription = await get_subscription(id, token) return await services.subscription_metrics(subscription_id=subscription.id) diff --git a/fastpubsub/api/routers/topics.py b/fastpubsub/api/routers/topics.py index 39262a4..75a586f 100644 --- a/fastpubsub/api/routers/topics.py +++ b/fastpubsub/api/routers/topics.py @@ -1,3 +1,5 @@ +"""API endpoints for topic management and message publishing operations.""" + from typing import Annotated, Any from fastapi import APIRouter, Depends, Query, status @@ -18,6 +20,22 @@ async def create_topic( data: models.CreateTopic, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("topics", "create"))], ): + """Create a new topic in the pub/sub system. + + Creates a topic that can be used to organize and publish messages. + Subscriptions can be created to consume messages from topics. + + Args: + data: Topic creation data including the unique topic ID. + token: Decoded client token with 'topics:create' scope. + + Returns: + Topic model with the created topic details. + + Raises: + AlreadyExistsError: If a topic with the same ID already exists. + InvalidClient: If the requesting client lacks 'topics:create' scope. + """ return await services.create_topic(data) @@ -31,6 +49,21 @@ async def create_topic( async def get_topic( id: str, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("topics", "read"))] ): + """Retrieve a topic by ID. + + Returns the full details of an existing topic including ID and creation timestamp. + + Args: + id: String ID of the topic to retrieve. + token: Decoded client token with 'topics:read' scope. + + Returns: + Topic model with full topic details. + + Raises: + NotFoundError: If no topic with the given ID exists. + InvalidClient: If the requesting client lacks 'topics:read' scope. + """ return await services.get_topic(id) @@ -45,6 +78,21 @@ async def list_topic( offset: int = Query(default=0, ge=0), limit: int = Query(default=10, ge=1, le=100), ): + """List topics with pagination support. + + Returns a paginated list of all topics in the system. + + Args: + token: Decoded client token with 'topics:read' scope. + offset: Number of items to skip (for pagination). + limit: Maximum number of items to return (1-100). + + Returns: + ListTopicAPI containing the list of topics. + + Raises: + InvalidClient: If the requesting client lacks 'topics:read' scope. + """ topics = await services.list_topic(offset, limit) return models.ListTopicAPI(data=topics) @@ -58,6 +106,19 @@ async def list_topic( async def delete_topic( id: str, token: Annotated[models.DecodedClientToken, Depends(services.require_scope("topics", "delete"))] ): + """Delete a topic by ID. + + Permanently removes a topic from the system. This action cannot be undone + and will also remove all subscriptions and messages associated with the topic. + + Args: + id: String ID of the topic to delete. + token: Decoded client token with 'topics:delete' scope. + + Raises: + NotFoundError: If no topic with the given ID exists. + InvalidClient: If the requesting client lacks 'topics:delete' scope. + """ await services.delete_topic(id) @@ -72,5 +133,22 @@ async def publish_messages( data: list[dict[str, Any]], token: Annotated[models.DecodedClientToken, Depends(services.require_scope("topics", "publish"))], ): + """Publish messages to a topic. + + Publishes one or more messages to a topic, making them available + for consumption by subscriptions to that topic. + + Args: + id: String ID of the topic to publish to. + data: List of message dictionaries to publish. + token: Decoded client token with 'topics:publish' scope. + + Returns: + Integer count of messages successfully published. + + Raises: + NotFoundError: If no topic with the given ID exists. + InvalidClient: If the requesting client lacks 'topics:publish' scope. + """ topic = await services.get_topic(id) return await services.publish_messages(topic_id=topic.id, messages=data) diff --git a/fastpubsub/api/server.py b/fastpubsub/api/server.py index d905a27..5460c35 100644 --- a/fastpubsub/api/server.py +++ b/fastpubsub/api/server.py @@ -1,24 +1,56 @@ +"""Gunicorn server configuration and startup for fastpubsub application.""" + from gunicorn.app.base import BaseApplication from fastpubsub.config import settings class CustomGunicornApp(BaseApplication): + """Custom Gunicorn application for running the FastAPI app. + + Extends BaseApplication to provide custom configuration and loading + of the FastAPI application for production deployment. + """ + def __init__(self, app, options=None): + """Initialize the custom Gunicorn application. + + Args: + app: The FastAPI application instance to run. + options: Optional dictionary of Gunicorn configuration options. + """ self.options = options or {} self.application = app super().__init__() def load_config(self): + """Load configuration settings from options dictionary. + + Sets Gunicorn configuration values from the options provided during initialization. + Only applies settings that are valid Gunicorn configuration keys. + """ for key, value in self.options.items(): if key in self.cfg.settings and value is not None: self.cfg.set(key, value) def load(self): + """Load the FastAPI application. + + Returns: + The FastAPI application instance to be served by Gunicorn. + """ return self.application def run_server(app): + """Start the Gunicorn server with the FastAPI application. + + Configures and starts the production HTTP server using Gunicorn + with Uvicorn workers for running the FastAPI application. + + Args: + app: The FastAPI application instance to serve. + """ options = { "bind": f"{settings.api_host}:{settings.api_port}", "workers": settings.api_num_workers, diff --git a/fastpubsub/config.py b/fastpubsub/config.py index 5fccfa1..f6c1915 100644 --- a/fastpubsub/config.py +++ b/fastpubsub/config.py @@ -1,3 +1,5 @@ +"""Configuration settings for fastpubsub application.""" + from enum import StrEnum from pydantic import Field, field_validator, model_validator @@ -6,6 +8,11 @@ class LogLevel(StrEnum): + """Enumeration of available logging levels. + + Used to specify the verbosity of logging output throughout the application. + """ + debug = "debug" info = "info" warning = "warning" @@ -14,6 +21,15 @@ class LogLevel(StrEnum): class Settings(BaseSettings): + """Application settings configuration. + + Manages all configuration parameters for the fastpubsub application, + including database, logging, subscription, API, worker, and authentication settings. + + Settings are loaded from environment variables with the 'fastpubsub_' prefix + or from a .env file in the application directory. + """ + # database database_url: str database_echo: bool = False @@ -53,12 +69,33 @@ class Settings(BaseSettings): @field_validator("database_url") def validate_database_url_format(cls, v: str): + """Validate database URL format. + + Args: + v: The database URL string to validate. + + Returns: + The validated database URL. + + Raises: + ValueError: If URL doesn't start with 'postgresql+psycopg://'. + """ if not v.startswith("postgresql+psycopg://"): raise ValueError("must start with 'postgresql+psycopg://'") return v @model_validator(mode="after") def check_subscription_backoff_order(self) -> "Settings": + """Validate subscription backoff timing configuration. + + Ensures that max backoff seconds is greater than or equal to min backoff seconds. + + Returns: + The validated Settings instance. + + Raises: + ValueError: If max backoff seconds is less than min backoff seconds. + """ if self.subscription_backoff_max_seconds < self.subscription_backoff_min_seconds: raise ValueError( "subscription_backoff_max_seconds must be greater than or equal to subscription_backoff_min_seconds" diff --git a/fastpubsub/database.py b/fastpubsub/database.py index 784bb20..a88682d 100644 --- a/fastpubsub/database.py +++ b/fastpubsub/database.py @@ -1,3 +1,5 @@ +"""Database models and utilities for fastpubsub application.""" + from pathlib import Path import sqlalchemy as sa @@ -23,11 +25,26 @@ class Base(DeclarativeBase): + """Base declarative class for all database models. + + Provides common functionality for all ORM models in the application. + """ + def to_dict(self): + """Convert model instance to dictionary. + + Returns: + Dictionary mapping column names to their values. + """ return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs} class Topic(Base): + """Database model representing a topic in the pub/sub system. + + Topics are used to organize messages and subscriptions. + """ + id = sa.Column(sa.Text, primary_key=True) created_at = sa.Column(sa.DateTime(timezone=True), nullable=False) @@ -38,6 +55,12 @@ def __repr__(self): class Subscription(Base): + """Database model representing a subscription to a topic. + + Subscriptions define how messages from a topic should be consumed, + including filtering, delivery attempts, and backoff configuration. + """ + id = sa.Column(sa.Text, primary_key=True) topic_id = sa.Column(sa.Text, nullable=False) filter = sa.Column(postgresql.JSONB, nullable=False, default={}) @@ -53,6 +76,11 @@ def __repr__(self): class SubscriptionMessage(Base): + """Database model representing a message in a subscription's queue. + + Tracks message delivery status, attempts, and processing state. + """ + id = sa.Column(postgresql.UUID, primary_key=True) subscription_id = sa.Column(sa.Text, nullable=False) payload = sa.Column(postgresql.JSONB, nullable=False) @@ -71,6 +99,11 @@ def __repr__(self): class Client(Base): + """Database model representing an authorized client of the pub/sub system. + + Clients can be granted scopes to perform specific operations. + """ + id = sa.Column(postgresql.UUID, primary_key=True) name = sa.Column(sa.Text, nullable=False) scopes = sa.Column(sa.Text, nullable=False) @@ -87,6 +120,17 @@ def __repr__(self): async def run_migrations(command_type: str = "upgrade", revision: str = "head") -> None: + """Run database migrations using Alembic. + + Executes database schema migrations to update or revert the database structure. + + Args: + command_type: Migration command to execute ('upgrade' or 'downgrade'). + revision: Alembic revision to apply ('head' for latest, specific revision ID, etc.). + + Raises: + Exception: If migration command fails. + """ parent_path = Path(__file__).parents[1] script_location = parent_path.joinpath(Path("migrations")) ini_location = parent_path.joinpath(Path("alembic.ini")) @@ -113,8 +157,24 @@ async def run_migrations(command_type: str = "upgrade", revision: str = "head") def is_unique_violation(exc: IntegrityError) -> bool: + """Check if an IntegrityError is a unique constraint violation. + + Args: + exc: The IntegrityError exception to check. + + Returns: + True if the exception is a unique constraint violation, False otherwise. + """ return "psycopg.errors.UniqueViolation" in exc.args[0] def is_foreign_key_violation(exc: IntegrityError) -> bool: + """Check if an IntegrityError is a foreign key constraint violation. + + Args: + exc: The IntegrityError exception to check. + + Returns: + True if the exception is a foreign key constraint violation, False otherwise. + """ return "psycopg.errors.ForeignKeyViolation" in exc.args[0] diff --git a/fastpubsub/exceptions.py b/fastpubsub/exceptions.py index 3ef4d37..143a1d5 100644 --- a/fastpubsub/exceptions.py +++ b/fastpubsub/exceptions.py @@ -1,18 +1,51 @@ +"""Custom exception classes for fastpubsub application.""" + + class NotFoundError(Exception): + """Exception raised when a requested resource is not found. + + This exception is typically used when trying to access or manipulate + database entities that don't exist. + """ + pass class AlreadyExistsError(Exception): + """Exception raised when attempting to create a resource that already exists. + + This exception is used for unique constraint violations where + a duplicate resource creation is attempted. + """ + pass class ServiceUnavailable(Exception): + """Exception raised when a service operation cannot be completed. + + This exception is used when external dependencies or services + are not available or operations fail unexpectedly. + """ + pass class InvalidClient(Exception): + """Exception raised when a client is not authorized or valid. + + This exception is used when client authentication fails or + a client lacks necessary permissions. + """ + pass class InvalidClientToken(Exception): + """Exception raised when a client token is invalid or expired. + + This exception is used when JWT token validation fails or + the token format/claims are incorrect. + """ + pass diff --git a/fastpubsub/logger.py b/fastpubsub/logger.py index 821fb92..d0ac03c 100644 --- a/fastpubsub/logger.py +++ b/fastpubsub/logger.py @@ -1,3 +1,5 @@ +"""Logging utilities for fastpubsub application.""" + import logging from pythonjsonlogger.json import JsonFormatter @@ -6,10 +8,26 @@ def get_log_level(level: str) -> int: + """Convert string log level to logging module constant. + + Args: + level: Log level as a string (debug, info, warning, error, critical). + + Returns: + Integer constant for the log level from the logging module. + + Raises: + AttributeError: If the level string is not valid. + """ return getattr(logging, level.upper()) def get_console_handler() -> logging.StreamHandler: + """Create and configure a console handler with JSON formatter. + + Returns: + Configured StreamHandler with JSON formatter for console output. + """ formatter = JsonFormatter(settings.log_formatter) console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) @@ -17,6 +35,14 @@ def get_console_handler() -> logging.StreamHandler: def get_logger(name: str) -> logging.Logger: + """Create and configure a logger with the specified name. + + Args: + name: Name for the logger, typically __name__ from the calling module. + + Returns: + Configured logger instance with appropriate log level and handlers. + """ logger = logging.getLogger(name) log_level = get_log_level(settings.log_level) logger.setLevel(log_level) diff --git a/fastpubsub/main.py b/fastpubsub/main.py index 248605e..8934007 100644 --- a/fastpubsub/main.py +++ b/fastpubsub/main.py @@ -1,3 +1,5 @@ +"""Command-line interface for fastpubsub application.""" + import asyncio from typing import Annotated @@ -16,7 +18,17 @@ async def _log_command_execution_async(command_name: str, func, *args, **kwargs): - """Helper to log async command execution with start and finish messages.""" + """Helper to log async command execution with start and finish messages. + + Args: + command_name: Name of the command being executed. + func: Async function to execute. + *args: Positional arguments to pass to the function. + **kwargs: Keyword arguments to pass to the function. + + Returns: + The result of the executed function. + """ logger.info(f"Starting {command_name} command") result = await func(*args, **kwargs) logger.info(f"Finishing {command_name} command") @@ -25,6 +37,11 @@ async def _log_command_execution_async(command_name: str, func, *args, **kwargs) @cli.command("db-migrate") def run_migrations_command() -> None: + """Run database migrations to upgrade to the latest schema. + + Executes all pending Alembic migrations to update the database schema + to the latest version. This is typically used during application deployment. + """ asyncio.run( _log_command_execution_async("db-migrate", run_migrations, command_type="upgrade", revision="head") ) @@ -32,6 +49,11 @@ def run_migrations_command() -> None: @cli.command("server") def run_server_command() -> None: + """Start the FastAPI server. + + Launches the HTTP server that serves the fastpubsub API endpoints. + This is a long-running command that will continue until stopped. + """ # Server is a long-running command, so we only log the start logger.info("Starting server command") run_server(app) @@ -39,6 +61,12 @@ def run_server_command() -> None: @cli.command("cleanup_acked_messages") def run_cleanup_acked_messages() -> None: + """Remove acknowledged messages older than the configured threshold. + + Cleans up message history by removing messages that have been acknowledged + and are older than the configured time threshold. Helps prevent database + bloat and improve performance. + """ asyncio.run( _log_command_execution_async( "cleanup_acked_messages", @@ -50,6 +78,12 @@ def run_cleanup_acked_messages() -> None: @cli.command("cleanup_stuck_messages") def run_cleanup_stuck_messages() -> None: + """Unlock messages that have been locked for too long. + + Releases locks on messages that have been locked beyond the timeout period, + making them available for consumption again. This helps recover from + consumer failures or crashes. + """ asyncio.run( _log_command_execution_async( "cleanup_stuck_messages", @@ -61,6 +95,11 @@ def run_cleanup_stuck_messages() -> None: @cli.command("generate_secret_key") def run_generate_secret_key() -> None: + """Generate a new random secret key for client authentication. + + Creates a cryptographically secure random string that can be used as + a client secret for JWT token generation and validation. + """ secret = generate_secret() typer.echo(f"new_secret={secret}") @@ -71,6 +110,16 @@ def run_create_client( scopes: Annotated[str, typer.Argument(help="The client scopes.")] = "*", is_active: Annotated[bool, typer.Argument(help="The flag to enable or disable client.")] = True, ) -> None: + """Create a new client with the specified name and scopes. + + Creates a new authorized client in the system that can access the + pub/sub API endpoints based on their granted scopes. + + Args: + name: Human-readable name for the client. + scopes: Space-separated list of permissions/scopes granted to the client. + is_active: Whether the client is initially active and can authenticate. + """ client_result = asyncio.run( create_client(data=CreateClient(name=name, scopes=scopes, is_active=is_active)) ) diff --git a/fastpubsub/models.py b/fastpubsub/models.py index f08bdf1..96a856b 100644 --- a/fastpubsub/models.py +++ b/fastpubsub/models.py @@ -1,3 +1,5 @@ +"""Pydantic models for fastpubsub API.""" + import uuid from datetime import datetime from typing import Annotated @@ -11,23 +13,65 @@ class GenericError(BaseModel): + """Generic error response model. + + Used for consistent error responses across the API. + + Attributes: + detail: Human-readable error message. + """ + detail: str class CreateTopic(BaseModel): + """Model for creating a new topic. + + Defines the required fields for topic creation. + + Attributes: + id: Unique identifier for the topic (alphanumeric with -._ chars, max 128 chars). + """ + id: str = Field(..., pattern=regex_for_id, max_length=128) class Topic(BaseModel): + """Model representing a topic in the pub/sub system. + + Attributes: + id: Unique identifier for the topic. + created_at: Timestamp when the topic was created. + """ + id: str created_at: datetime class ListTopicAPI(BaseModel): + """Model for paginated topic list response. + + Attributes: + data: List of topic objects. + """ + data: list[Topic] class CreateSubscription(BaseModel): + """Model for creating a new subscription. + + Defines the required fields for subscription creation with delivery and backoff configuration. + + Attributes: + id: Unique identifier for the subscription (alphanumeric with -._ chars, max 128 chars). + topic_id: ID of the topic to subscribe to. + filter: Optional JSON filter to apply to messages. + max_delivery_attempts: Maximum number of delivery attempts (defaults to settings). + backoff_min_seconds: Minimum backoff time between attempts (defaults to settings). + backoff_max_seconds: Maximum backoff time between attempts (defaults to settings). + """ + id: str = Field(..., pattern=regex_for_id, max_length=128) topic_id: str = Field(..., pattern=regex_for_id, max_length=128) filter: dict | None = None @@ -38,11 +82,30 @@ class CreateSubscription(BaseModel): @field_validator("filter") @classmethod def sanitize_filter_field(cls, v): - """Sanitize filter to prevent SQL and XSS injection attacks.""" + """Sanitize filter to prevent SQL and XSS injection attacks. + + Args: + v: The filter dictionary to sanitize. + + Returns: + Sanitized filter dictionary. + """ return sanitize_filter(v) class Subscription(BaseModel): + """Model representing a subscription in the pub/sub system. + + Attributes: + id: Unique identifier for the subscription. + topic_id: ID of the subscribed topic. + filter: JSON filter applied to messages. + max_delivery_attempts: Maximum delivery attempts for failed messages. + backoff_min_seconds: Minimum backoff time between attempts. + backoff_max_seconds: Maximum backoff time between attempts. + created_at: Timestamp when the subscription was created. + """ + id: str topic_id: str filter: dict | None @@ -53,10 +116,26 @@ class Subscription(BaseModel): class ListSubscriptionAPI(BaseModel): + """Model for paginated subscription list response. + + Attributes: + data: List of subscription objects. + """ + data: list[Subscription] class Message(BaseModel): + """Model representing a message in the pub/sub system. + + Attributes: + id: Unique identifier for the message. + subscription_id: ID of the subscription this message belongs to. + payload: JSON payload of the message. + delivery_attempts: Number of delivery attempts made. + created_at: Timestamp when the message was created. + """ + id: uuid.UUID subscription_id: str payload: dict @@ -65,10 +144,28 @@ class Message(BaseModel): class ListMessageAPI(BaseModel): + """Model for paginated message list response. + + Attributes: + data: List of message objects. + """ + data: list[Message] class SubscriptionMetrics(BaseModel): + """Model for subscription metrics and statistics. + + Provides counts of messages in different states for a subscription. + + Attributes: + subscription_id: ID of the subscription. + available: Number of messages available for consumption. + delivered: Number of messages delivered but not yet acked. + acked: Number of messages successfully acknowledged. + dlq: Number of messages in dead letter queue. + """ + subscription_id: str available: int delivered: int @@ -77,16 +174,43 @@ class SubscriptionMetrics(BaseModel): class HealthCheck(BaseModel): + """Model for application health check response. + + Attributes: + status: Health status string (e.g., "ok"). + """ + status: str class CreateClient(BaseModel): + """Model for creating a new client. + + Defines the required fields for client creation with authentication configuration. + + Attributes: + name: Human-readable name for the client. + scopes: Space-separated list of permissions/scopes. + is_active: Whether the client is active and can authenticate. + """ + name: Annotated[str, StringConstraints(min_length=1, strip_whitespace=True)] scopes: Annotated[str, StringConstraints(min_length=1, strip_whitespace=True)] is_active: bool = True @field_validator("scopes") def validate_scopes(cls, v: str): + """Validate that all scopes are among the allowed values. + + Args: + v: Space-separated scopes string to validate. + + Returns: + The validated scopes string. + + Raises: + ValueError: If any scope is invalid. + """ valid_scopes = ( "*", "topics:create", @@ -112,11 +236,32 @@ def validate_scopes(cls, v: str): class CreateClientResult(BaseModel): + """Model for client creation response. + + Contains the newly created client's ID and generated secret. + + Attributes: + id: Unique identifier of the created client. + secret: Generated secret key for the client. + """ + id: uuid.UUID secret: str class Client(BaseModel): + """Model representing an authorized client. + + Attributes: + id: Unique identifier for the client. + name: Human-readable name for the client. + scopes: Space-separated list of granted permissions. + is_active: Whether the client can currently authenticate. + token_version: Version counter for token invalidation. + created_at: Timestamp when the client was created. + updated_at: Timestamp when the client was last updated. + """ + id: uuid.UUID name: str scopes: str @@ -127,10 +272,24 @@ class Client(BaseModel): class UpdateClient(CreateClient): + """Model for updating an existing client. + + Inherits all fields from CreateClient. Used for partial updates. + """ + pass class ClientToken(BaseModel): + """Model for JWT access token response. + + Attributes: + access_token: JWT access token string. + token_type: Type of token (default: "Bearer"). + expires_in: Token expiration time in seconds. + scope: Space-separated list of granted scopes. + """ + access_token: str token_type: str = "Bearer" expires_in: int @@ -138,14 +297,34 @@ class ClientToken(BaseModel): class DecodedClientToken(BaseModel): + """Model for decoded JWT token payload. + + Attributes: + client_id: ID of the client this token belongs to. + scopes: Set of granted scopes from the token. + """ + client_id: uuid.UUID scopes: set[str] class ListClientAPI(BaseModel): + """Model for paginated client list response. + + Attributes: + data: List of client objects. + """ + data: list[Client] class IssueClientToken(BaseModel): + """Model for requesting a new client token. + + Attributes: + client_id: ID of the client requesting a token. + client_secret: Secret key for client authentication. + """ + client_id: uuid.UUID client_secret: str diff --git a/fastpubsub/sanitizer.py b/fastpubsub/sanitizer.py index d6277a9..bc1828f 100644 --- a/fastpubsub/sanitizer.py +++ b/fastpubsub/sanitizer.py @@ -1,15 +1,13 @@ -""" -Sanitization utilities for preventing SQL and XSS injection attacks. -""" +"""Sanitization utilities for preventing SQL and XSS injection attacks.""" import html import re def sanitize_string(value: str) -> str: - """ - Sanitize a string value to prevent XSS attacks. + """Sanitize a string value to prevent XSS attacks. + Performs the following sanitization: - HTML entity encoding to prevent script injection - Remove null bytes and control characters @@ -17,7 +15,7 @@ def sanitize_string(value: str) -> str: value: String to sanitize Returns: - Sanitized string + Sanitized string safe for use in web contexts. """ if not isinstance(value, str): return value @@ -33,8 +31,7 @@ def sanitize_string(value: str) -> str: def validate_filter_structure(filter_dict: dict | None) -> bool: - """ - Validate that a filter has the correct structure. + """Validate that a filter has the correct structure. Expected structure: {"field_name": ["value1", "value2", ...]} - Keys must be strings @@ -45,7 +42,7 @@ def validate_filter_structure(filter_dict: dict | None) -> bool: filter_dict: Filter dictionary to validate Returns: - True if valid, False otherwise + True if valid structure, False otherwise. """ if filter_dict is None or filter_dict == {}: return True @@ -74,8 +71,7 @@ def validate_filter_structure(filter_dict: dict | None) -> bool: def sanitize_filter(filter_dict: dict | None) -> dict | None: - """ - Sanitize a filter dictionary to prevent SQL and XSS injection attacks. + """Sanitize a filter dictionary to prevent SQL and XSS injection attacks. This function: 1. Validates the filter structure @@ -86,10 +82,10 @@ def sanitize_filter(filter_dict: dict | None) -> dict | None: filter_dict: Filter dictionary to sanitize Returns: - Sanitized filter dictionary + Sanitized filter dictionary safe for database queries and web contexts. Raises: - ValueError: If filter structure is invalid + ValueError: If filter structure is invalid. """ if filter_dict is None or filter_dict == {}: return filter_dict diff --git a/fastpubsub/services/auth.py b/fastpubsub/services/auth.py index 3da5d6e..2fe98a5 100644 --- a/fastpubsub/services/auth.py +++ b/fastpubsub/services/auth.py @@ -1,3 +1,5 @@ +"""Authentication and authorization services for fastpubsub.""" + from typing import Annotated from fastapi import Depends, Request @@ -12,6 +14,20 @@ def has_scope(token_scopes: set[str], resource: str, action: str, resource_id: str | None = None) -> bool: + """Check if token scopes include the required permission. + + Evaluates if the provided token scopes grant permission to perform + the requested action on the specified resource. + + Args: + token_scopes: Set of scopes from the client's token. + resource: Resource type (e.g., 'topics', 'subscriptions', 'clients'). + action: Action to perform (e.g., 'create', 'read', 'update', 'delete', 'publish', 'consume'). + resource_id: Optional specific resource ID for fine-grained permissions. + + Returns: + True if the required scope is granted, False otherwise. + """ if "*" in token_scopes: return True @@ -27,12 +43,39 @@ def has_scope(token_scopes: set[str], resource: str, action: str, resource_id: s async def get_current_token(token: str | None = Depends(oauth2_scheme)) -> DecodedClientToken: + """Extract and decode the current client's JWT token. + + Args: + token: OAuth2 bearer token from the request header. + + Returns: + DecodedClientToken with client ID and scopes. + + Raises: + InvalidClientToken: If token is invalid or authentication fails. + """ if token is None: token = "" return await services.decode_jwt_client_token(token, auth_enabled=settings.auth_enabled) def require_scope(resource: str, action: str): + """Create a dependency that requires specific scope for API endpoints. + + Generates a FastAPI dependency that validates the incoming request + has the required scope for the specified resource and action. + + Args: + resource: Resource type (e.g., 'topics', 'subscriptions', 'clients'). + action: Action to perform (e.g., 'create', 'read', 'update', 'delete', 'publish', 'consume'). + + Returns: + FastAPI dependency function that validates scope requirements. + + Raises: + InvalidClientToken: If client lacks required scope. + """ + async def dependency(request: Request, token: Annotated[DecodedClientToken, Depends(get_current_token)]): resource_id = request.path_params.get("id") if resource_id is not None: diff --git a/fastpubsub/services/clients.py b/fastpubsub/services/clients.py index 55cd9b2..af2cf00 100644 --- a/fastpubsub/services/clients.py +++ b/fastpubsub/services/clients.py @@ -1,3 +1,5 @@ +"""Client management services for authentication and authorization.""" + import datetime import secrets import uuid @@ -25,10 +27,33 @@ def generate_secret() -> str: + """Generate a cryptographically secure random secret. + + Creates a random 32-character hexadecimal string that can be used + as a client secret for JWT token authentication. + + Returns: + Random secret string in hexadecimal format. + """ return secrets.token_hex(16) async def create_client(data: CreateClient) -> CreateClientResult: + """Create a new client with authentication credentials. + + Creates a new client in the database with a generated secret and + initializes the client with the provided configuration. + + Args: + data: Client creation data including name, scopes, and active status. + + Returns: + CreateClientResult containing the new client ID and generated secret. + + Raises: + AlreadyExistsError: If a client with the same ID already exists. + ValueError: If client data validation fails. + """ async with SessionLocal() as session: now = utc_now() secret = generate_secret() @@ -51,6 +76,19 @@ async def create_client(data: CreateClient) -> CreateClientResult: async def get_client(client_id: uuid.UUID) -> Client: + """Retrieve a client by ID. + + Fetches the full details of an existing client from the database. + + Args: + client_id: UUID of the client to retrieve. + + Returns: + Client model with full client details. + + Raises: + NotFoundError: If no client with the given ID exists. + """ async with SessionLocal() as session: db_client = await _get_entity(session, DBClient, client_id, "Client not found") @@ -58,6 +96,17 @@ async def get_client(client_id: uuid.UUID) -> Client: async def list_client(offset: int, limit: int) -> list[Client]: + """List clients with pagination support. + + Retrieves a paginated list of all clients in the system. + + Args: + offset: Number of items to skip for pagination. + limit: Maximum number of items to return. + + Returns: + List of Client models. + """ async with SessionLocal() as session: stmt = select(DBClient).order_by(DBClient.id.asc()).offset(offset).limit(limit) result = await session.execute(stmt) @@ -67,6 +116,21 @@ async def list_client(offset: int, limit: int) -> list[Client]: async def update_client(client_id: uuid.UUID, data: UpdateClient) -> Client: + """Update an existing client's properties. + + Modifies the properties of an existing client and increments + the token version to invalidate existing tokens. + + Args: + client_id: UUID of the client to update. + data: Updated client data including name, scopes, and active status. + + Returns: + Client model with updated details. + + Raises: + NotFoundError: If no client with the given ID exists. + """ async with SessionLocal() as session: db_client = await _get_entity(session, DBClient, client_id, "Client not found") db_client.name = data.name @@ -81,11 +145,36 @@ async def update_client(client_id: uuid.UUID, data: UpdateClient) -> Client: async def delete_client(client_id: uuid.UUID) -> None: + """Delete a client by ID. + + Permanently removes a client from the database and all associated data. + + Args: + client_id: UUID of the client to delete. + + Raises: + NotFoundError: If no client with the given ID exists. + """ async with SessionLocal() as session: await _delete_entity(session, DBClient, client_id, "Client not found") async def issue_jwt_client_token(client_id: uuid.UUID, client_secret: str) -> ClientToken: + """Issue a new JWT access token for a client. + + Validates client credentials and generates a new access token + with the client's scopes and expiration time. + + Args: + client_id: UUID of the client requesting a token. + client_secret: Client's secret for authentication. + + Returns: + ClientToken containing the access token, expiration, and scopes. + + Raises: + InvalidClient: If client credentials are invalid or client is disabled. + """ async with SessionLocal() as session: db_client = await _get_entity(session, DBClient, client_id, "Client not found", raise_exception=False) if not db_client: @@ -114,6 +203,21 @@ async def issue_jwt_client_token(client_id: uuid.UUID, client_secret: str) -> Cl async def decode_jwt_client_token(access_token: str, auth_enabled: bool = True) -> DecodedClientToken: + """Decode and validate a JWT access token. + + Validates the token signature, expiration, and client status. + Ensures the token hasn't been revoked by checking token version. + + Args: + access_token: JWT access token to decode and validate. + auth_enabled: Whether authentication is enabled (for testing). + + Returns: + DecodedClientToken with client ID and scopes. + + Raises: + InvalidClient: If token is invalid, expired, or client is disabled/revoked. + """ if not auth_enabled: return DecodedClientToken(client_id=uuid.uuid7(), scopes={"*"}) diff --git a/fastpubsub/services/helpers.py b/fastpubsub/services/helpers.py index f5f45cf..2bb6c57 100644 --- a/fastpubsub/services/helpers.py +++ b/fastpubsub/services/helpers.py @@ -1,3 +1,5 @@ +"""Helper functions for service layer operations.""" + import datetime import uuid @@ -8,13 +10,32 @@ def utc_now(): + """Get current UTC timestamp. + + Returns: + Current datetime with UTC timezone. + """ return datetime.datetime.now(datetime.UTC) async def _get_entity( session, model, entity_id: str | uuid.UUID, error_message: str, raise_exception: bool = True ): - """Generic helper to get an entity by ID or raise NotFoundError.""" + """Generic helper to get an entity by ID or raise NotFoundError. + + Args: + session: Database session to use for the query. + model: SQLAlchemy model class to query. + entity_id: ID of the entity to retrieve. + error_message: Error message to include in NotFoundError. + raise_exception: Whether to raise NotFoundError if entity is not found. + + Returns: + The entity instance if found, None if not found and raise_exception is False. + + Raises: + NotFoundError: If entity is not found and raise_exception is True. + """ stmt = select(model).filter_by(id=entity_id) result = await session.execute(stmt) entity = result.scalar_one_or_none() @@ -24,7 +45,17 @@ async def _get_entity( async def _delete_entity(session, model, entity_id: str | uuid.UUID, error_message: str) -> None: - """Generic helper to delete an entity by ID or raise NotFoundError.""" + """Generic helper to delete an entity by ID or raise NotFoundError. + + Args: + session: Database session to use for the operation. + model: SQLAlchemy model class to delete from. + entity_id: ID of the entity to delete. + error_message: Error message to include in NotFoundError. + + Raises: + NotFoundError: If entity with the given ID doesn't exist. + """ entity = await _get_entity(session, model, entity_id, error_message) await session.delete(entity) await session.commit() diff --git a/fastpubsub/services/messages.py b/fastpubsub/services/messages.py index afbde4b..8b3a279 100644 --- a/fastpubsub/services/messages.py +++ b/fastpubsub/services/messages.py @@ -1,3 +1,5 @@ +"""Message operations service for publishing, consuming, and managing pub/sub messages.""" + import uuid from typing import Any @@ -10,6 +12,15 @@ async def publish_messages(topic_id: str, messages: list[dict[str, Any]]) -> int: + """Publish messages to a topic. + + Args: + topic_id: ID of the topic to publish messages to. + messages: List of message dictionaries to publish. + + Returns: + Number of messages successfully published. + """ query = "SELECT publish_messages(:topic_id, CAST(:messages AS jsonb[]))" stmt = text(query).bindparams(topic_id=topic_id, messages=messages) jsonb_array = [Json(m) for m in messages] @@ -26,6 +37,16 @@ async def publish_messages(topic_id: str, messages: list[dict[str, Any]]) -> int async def consume_messages(subscription_id: str, consumer_id: str, batch_size: int) -> list[Message]: + """Consume messages from a subscription. + + Args: + subscription_id: ID of the subscription to consume from. + consumer_id: Unique identifier for the consumer. + batch_size: Number of messages to retrieve. + + Returns: + List of available messages for consumption. + """ query = "SELECT * FROM consume_messages(:subscription_id, :consumer_id, :batch_size)" stmt = text(query) @@ -45,16 +66,44 @@ async def consume_messages(subscription_id: str, consumer_id: str, batch_size: i async def ack_messages(subscription_id: str, message_ids: list[uuid.UUID]) -> bool: + """Acknowledge successful processing of messages. + + Args: + subscription_id: ID of the subscription. + message_ids: List of message UUIDs to acknowledge. + + Returns: + True if exactly one row was affected, False otherwise. + """ query = "SELECT ack_messages(:subscription_id, :message_ids)" return await _execute_sql_command(query, {"subscription_id": subscription_id, "message_ids": message_ids}) async def nack_messages(subscription_id: str, message_ids: list[uuid.UUID]) -> bool: + """Negative acknowledgment of message processing failure. + + Args: + subscription_id: ID of the subscription. + message_ids: List of message UUIDs to negatively acknowledge. + + Returns: + True if exactly one row was affected, False otherwise. + """ query = "SELECT nack_messages(:subscription_id, :message_ids)" return await _execute_sql_command(query, {"subscription_id": subscription_id, "message_ids": message_ids}) async def list_dlq_messages(subscription_id: str, offset: int = 0, limit: int = 100) -> list[Message]: + """List messages in the dead letter queue. + + Args: + subscription_id: ID of the subscription. + offset: Number of items to skip for pagination. + limit: Maximum number of items to return. + + Returns: + List of messages in the DLQ. + """ query = "SELECT * FROM list_dlq_messages(:subscription_id, :offset, :limit)" stmt = text(query) @@ -73,21 +122,54 @@ async def list_dlq_messages(subscription_id: str, offset: int = 0, limit: int = async def reprocess_dlq_messages(subscription_id: str, message_ids: list[uuid.UUID]) -> bool: + """Move dead letter queue messages back to active processing. + + Args: + subscription_id: ID of the subscription. + message_ids: List of message UUIDs to reprocess. + + Returns: + True if exactly one row was affected, False otherwise. + """ query = "SELECT reprocess_dlq_messages(:subscription_id, :message_ids)" return await _execute_sql_command(query, {"subscription_id": subscription_id, "message_ids": message_ids}) async def cleanup_stuck_messages(lock_timeout_seconds: int) -> bool: + """Unlock messages that have been locked for too long. + + Args: + lock_timeout_seconds: Timeout threshold for stuck messages. + + Returns: + True if cleanup was successful, False otherwise. + """ query = "SELECT cleanup_stuck_messages(make_interval(secs => :timeout))" return await _execute_sql_command(query, {"timeout": lock_timeout_seconds}) async def cleanup_acked_messages(older_than_seconds: int) -> bool: + """Remove acknowledged messages older than the threshold. + + Args: + older_than_seconds: Age threshold for message cleanup. + + Returns: + True if cleanup was successful, False otherwise. + """ query = "SELECT cleanup_acked_messages(make_interval(secs => :older_than))" return await _execute_sql_command(query, {"older_than": older_than_seconds}) async def subscription_metrics(subscription_id: str) -> SubscriptionMetrics: + """Get metrics and statistics for a subscription. + + Args: + subscription_id: ID of the subscription. + + Returns: + SubscriptionMetrics containing message counts by state. + """ query = "SELECT * FROM subscription_metrics(:subscription_id)" stmt = text(query) @@ -101,6 +183,11 @@ async def subscription_metrics(subscription_id: str) -> SubscriptionMetrics: async def database_ping() -> bool: + """Check database connectivity. + + Returns: + True if database is reachable, False otherwise. + """ async with SessionLocal() as session: result = await session.scalar(select(1)) return result == 1 diff --git a/fastpubsub/services/subscriptions.py b/fastpubsub/services/subscriptions.py index 3607ad8..9bc258c 100644 --- a/fastpubsub/services/subscriptions.py +++ b/fastpubsub/services/subscriptions.py @@ -1,3 +1,5 @@ +"""Subscription management services for creating and managing topic subscriptions.""" + from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -9,6 +11,18 @@ async def create_subscription(data: CreateSubscription) -> Subscription: + """Create a new subscription to a topic. + + Args: + data: Subscription creation data including ID, topic ID, filter, and delivery settings. + + Returns: + Subscription model with the created subscription details. + + Raises: + AlreadyExistsError: If a subscription with the same ID already exists. + NotFoundError: If the specified topic doesn't exist. + """ async with SessionLocal() as session: db_subscription = DBSubscription( id=data.id, @@ -34,6 +48,17 @@ async def create_subscription(data: CreateSubscription) -> Subscription: async def get_subscription(subscription_id: str) -> Subscription: + """Retrieve a subscription by ID. + + Args: + subscription_id: ID of the subscription to retrieve. + + Returns: + Subscription model with full subscription details. + + Raises: + NotFoundError: If no subscription with the given ID exists. + """ async with SessionLocal() as session: db_subscription = await _get_entity( session, DBSubscription, subscription_id, "Subscription not found" @@ -42,6 +67,15 @@ async def get_subscription(subscription_id: str) -> Subscription: async def list_subscription(offset: int, limit: int) -> list[Subscription]: + """List subscriptions with pagination support. + + Args: + offset: Number of items to skip for pagination. + limit: Maximum number of items to return. + + Returns: + List of Subscription models. + """ async with SessionLocal() as session: stmt = select(DBSubscription).order_by(DBSubscription.id.asc()).offset(offset).limit(limit) result = await session.execute(stmt) @@ -50,5 +84,13 @@ async def list_subscription(offset: int, limit: int) -> list[Subscription]: async def delete_subscription(subscription_id: str) -> None: + """Delete a subscription by ID. + + Args: + subscription_id: ID of the subscription to delete. + + Raises: + NotFoundError: If no subscription with the given ID exists. + """ async with SessionLocal() as session: await _delete_entity(session, DBSubscription, subscription_id, "Subscription not found") diff --git a/fastpubsub/services/topics.py b/fastpubsub/services/topics.py index a645fd5..0991854 100644 --- a/fastpubsub/services/topics.py +++ b/fastpubsub/services/topics.py @@ -1,3 +1,5 @@ +"""Topic management services for creating and managing pub/sub topics.""" + from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -9,6 +11,17 @@ async def create_topic(data: CreateTopic) -> Topic: + """Create a new topic in the pub/sub system. + + Args: + data: Topic creation data including the unique topic ID. + + Returns: + Topic model with the created topic details. + + Raises: + AlreadyExistsError: If a topic with the same ID already exists. + """ async with SessionLocal() as session: db_topic = DBTopic(id=data.id, created_at=utc_now()) session.add(db_topic) @@ -24,12 +37,32 @@ async def create_topic(data: CreateTopic) -> Topic: async def get_topic(topic_id: str) -> Topic: + """Retrieve a topic by ID. + + Args: + topic_id: ID of the topic to retrieve. + + Returns: + Topic model with full topic details. + + Raises: + NotFoundError: If no topic with the given ID exists. + """ async with SessionLocal() as session: db_topic = await _get_entity(session, DBTopic, topic_id, "Topic not found") return Topic(**db_topic.to_dict()) async def list_topic(offset: int, limit: int) -> list[Topic]: + """List topics with pagination support. + + Args: + offset: Number of items to skip for pagination. + limit: Maximum number of items to return. + + Returns: + List of Topic models. + """ async with SessionLocal() as session: stmt = select(DBTopic).order_by(DBTopic.id.asc()).offset(offset).limit(limit) result = await session.execute(stmt) @@ -38,5 +71,13 @@ async def list_topic(offset: int, limit: int) -> list[Topic]: async def delete_topic(topic_id: str) -> None: + """Delete a topic by ID. + + Args: + topic_id: ID of the topic to delete. + + Raises: + NotFoundError: If no topic with the given ID exists. + """ async with SessionLocal() as session: await _delete_entity(session, DBTopic, topic_id, "Topic not found") diff --git a/tests/conftest.py b/tests/conftest.py index 7bee1dc..194aeda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +"""Test configuration and fixtures for fastpubsub application tests.""" + import pytest import pytest_asyncio from fastapi.testclient import TestClient @@ -17,7 +19,14 @@ @pytest_asyncio.fixture(scope="session") async def async_engine(): - """Create an async engine for testing.""" + """Create an async engine for testing. + + Sets up a test database with migrations applied for the test session. + Tears down the database after all tests complete. + + Yields: + Async SQLAlchemy engine configured for testing. + """ await run_migrations(command_type="upgrade", revision="head") yield engine await run_migrations(command_type="downgrade", revision="-1") @@ -26,6 +35,17 @@ async def async_engine(): @pytest_asyncio.fixture(scope="function") async def session(async_engine): + """Create a database session for each test function. + + Provides a clean database session that automatically cleans up + all test data after each test function completes. + + Args: + async_engine: The async database engine fixture. + + Yields: + Async database session for test operations. + """ async with SessionLocal() as sess: yield sess # Clean up after each test @@ -38,4 +58,9 @@ async def session(async_engine): @pytest.fixture def client(): + """Create a test client for FastAPI application. + + Returns: + TestClient configured for testing the FastAPI application. + """ return TestClient(app) diff --git a/tests/helpers.py b/tests/helpers.py index a29985c..83dccbf 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,6 +1,17 @@ +"""Helper functions for test utilities.""" + import asyncio def sync_call_function(service, *args, **kwargs): - """Helper function to run async functions in sync tests.""" + """Helper function to run async functions in sync tests. + + Args: + service: Async function to call. + *args: Positional arguments to pass to the async function. + **kwargs: Keyword arguments to pass to the async function. + + Returns: + Result of the async function execution. + """ return asyncio.run(service(*args, **kwargs)) diff --git a/tests/test_config.py b/tests/test_config.py index da62e42..f2430ba 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,5 @@ +"""Tests for configuration settings validation.""" + import pytest from pydantic import ValidationError @@ -5,18 +7,33 @@ def test_settings_log_level(): + """Test that invalid log level values are rejected. + + Validates that Settings raises a ValidationError when an invalid + log level string is provided. + """ with pytest.raises(ValidationError) as excinfo: Settings(log_level="invalid") assert "Input should be 'debug', 'info', 'warning', 'error' or 'critical'" in str(excinfo.value) def test_settings_database_url_format(): + """Test that database URL format is validated. + + Validates that Settings requires the correct PostgreSQL URL format + starting with 'postgresql+psycopg://'. + """ with pytest.raises(ValidationError) as excinfo: Settings(database_url="postgresql://fastpubsub:fastpubsub@localhost:5432/fastpubsub") assert "must start with 'postgresql+psycopg://'" in str(excinfo.value) def test_settings_subscription_backoff_order(): + """Test that subscription backoff timing is validated. + + Validates that max backoff seconds must be greater than or equal + to min backoff seconds in subscription configuration. + """ with pytest.raises(ValidationError) as excinfo: Settings(subscription_backoff_min_seconds=5, subscription_backoff_max_seconds=4) assert ( diff --git a/tests/test_models.py b/tests/test_models.py index 83c16e7..a20a1ec 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,3 +1,5 @@ +"""Tests for Pydantic model validation.""" + import pytest from pydantic import ValidationError @@ -23,6 +25,14 @@ ], ) def test_create_client_with_valid_scopes(scopes): + """Test that valid client scopes are accepted. + + Validates that CreateClient model accepts all valid scope strings + defined in the system. + + Args: + scopes: Valid scope string to test. + """ client = CreateClient(name="my client", scopes=scopes) assert client.scopes == scopes @@ -47,5 +57,13 @@ def test_create_client_with_valid_scopes(scopes): ], ) def test_create_client_with_invalid_scopes(scopes): + """Test that invalid client scopes are rejected. + + Validates that CreateClient model rejects invalid scope strings + that don't follow the expected naming convention. + + Args: + scopes: Invalid scope string to test. + """ with pytest.raises(ValidationError): CreateClient(name="my client", scopes=scopes)