diff --git a/python_template_server/main.py b/python_template_server/main.py index 9473cf6..6613bee 100644 --- a/python_template_server/main.py +++ b/python_template_server/main.py @@ -3,20 +3,36 @@ from pathlib import Path from typing import Any -from python_template_server.constants import CONFIG_FILE_PATH +from python_template_server.constants import CONFIG_FILE_PATH, STATIC_DIR from python_template_server.models import TemplateServerConfig +from python_template_server.routers import BaseRouter from python_template_server.template_server import TemplateServer class ExampleServer(TemplateServer): """Example server inheriting from TemplateServer.""" - def __init__(self, config_filepath: Path = CONFIG_FILE_PATH) -> None: + def __init__( + self, + config_filepath: Path = CONFIG_FILE_PATH, + config: TemplateServerConfig | None = None, + static_dir: Path = STATIC_DIR, + ) -> None: """Initialize the ExampleServer by delegating to the template server. + :param TemplateServerConfig config: Configuration object :param Path config_filepath: Configuration filepath + :param Path static_dir: Static files directory """ - super().__init__(config_filepath=config_filepath) + super().__init__(config=config, config_filepath=config_filepath, static_dir=static_dir) + + @property + def routers(self) -> list[BaseRouter]: + """Define the API routers for the server. + + :return list[BaseRouter]: List of API routers + """ + return [] def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig: """Validate configuration from the config.json file. @@ -25,10 +41,6 @@ def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig: """ return super().validate_config(config_data) - def setup_routes(self) -> None: - """Set up API routes.""" - pass - def run() -> None: """Serve the FastAPI application using uvicorn. diff --git a/python_template_server/routers/__init__.py b/python_template_server/routers/__init__.py new file mode 100644 index 0000000..1a9c228 --- /dev/null +++ b/python_template_server/routers/__init__.py @@ -0,0 +1,6 @@ +"""Routers for the FastAPI server.""" + +from .base_router import BaseRouter +from .template_server_router import TemplateServerRouter + +__all__ = ["BaseRouter", "TemplateServerRouter"] diff --git a/python_template_server/routers/base_router.py b/python_template_server/routers/base_router.py new file mode 100644 index 0000000..3481aaf --- /dev/null +++ b/python_template_server/routers/base_router.py @@ -0,0 +1,109 @@ +"""Base router for the FastAPI server.""" + +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable + +from fastapi import APIRouter, HTTPException, Security +from fastapi.security import APIKeyHeader +from pydantic import BaseModel +from slowapi import Limiter + +from python_template_server.authentication_handler import verify_token +from python_template_server.constants import API_KEY_HEADER_NAME +from python_template_server.models import ResponseCode + +logger = logging.getLogger(__name__) + + +API_KEY_HEADER = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False) + + +class BaseRouter(ABC): + """Abstract base class for API routers.""" + + def __init__(self, prefix: str) -> None: + """Initialize the base router.""" + self.router = APIRouter(prefix=prefix) + + self.hashed_token: str = "" + self.limiter: Limiter | None + self.rate_limit: str + + @abstractmethod + def setup_routes(self) -> None: + """Abstract method to set up API routes.""" + pass + + async def _verify_api_key(self, api_key: str | None = Security(API_KEY_HEADER)) -> None: + """Verify the API key from the request header. + + :param str | None api_key: The API key from the X-API-Key header + :raise HTTPException: If the API key is missing or invalid + """ + if api_key is None: + logger.warning("Missing API key in request!") + raise HTTPException( + status_code=ResponseCode.BAD_REQUEST, + detail="Missing API key", + ) + + try: + if not verify_token(api_key, self.hashed_token): + logger.warning("Invalid API key attempt!") + raise HTTPException( + status_code=ResponseCode.UNAUTHORIZED, + detail="Invalid API key", + ) + except ValueError as e: + logger.exception("Error verifying API key!") + raise HTTPException( + status_code=ResponseCode.INTERNAL_SERVER_ERROR, + detail=str(e), + ) from e + + def configure(self, hashed_token: str, limiter: Limiter | None, rate_limit: str) -> None: + """Configure the router with shared dependencies. + + :param str hashed_token: The hashed token for API key verification + :param Limiter | None limiter: The rate limiter instance to use for this router + :param str rate_limit: The rate limit string to apply to limited routes + """ + self.hashed_token = hashed_token + self.limiter = limiter + self.rate_limit = rate_limit + + def add_route( + self, + endpoint: str, + handler_function: Callable, + response_model: type[BaseModel], + methods: list[str], + limited: bool, # noqa: FBT001 + authentication_required: bool, # noqa: FBT001 + ) -> None: + """Add an API route. + + :param str endpoint: The API endpoint path + :param Callable handler_function: The handler function for the endpoint + :param BaseModel response_model: The Pydantic model for the response + :param list[str] methods: The HTTP methods for the endpoint + :param bool limited: Whether to apply rate limiting to this route + :param bool authentication_required: Whether authentication is required for this route + """ + try: + limited_method = None + if limited and self.limiter is not None: + limited_method = self.limiter.limit(self.rate_limit)(handler_function) + + self.router.add_api_route( + path=endpoint, + endpoint=limited_method or handler_function, + methods=methods, + response_model=response_model, + dependencies=[Security(self._verify_api_key)] if authentication_required else None, + ) + except AttributeError as e: + error_msg = "Router not configured with limiter and rate limit. Call configure() before adding routes." + logger.exception(error_msg) + raise RuntimeError(error_msg) from e diff --git a/python_template_server/routers/template_server_router.py b/python_template_server/routers/template_server_router.py new file mode 100644 index 0000000..74c54b3 --- /dev/null +++ b/python_template_server/routers/template_server_router.py @@ -0,0 +1,47 @@ +"""Template server router with health and login endpoints.""" + +from fastapi import Request + +from python_template_server.models import GetHealthResponse, GetLoginResponse +from python_template_server.routers import BaseRouter + + +class TemplateServerRouter(BaseRouter): + """Router for the template server with health and login endpoints.""" + + def setup_routes(self) -> None: + """Set up the API routes for the template server.""" + self.add_route( + endpoint="/health", + handler_function=self.get_health, + response_model=GetHealthResponse, + methods=["GET"], + limited=False, + authentication_required=False, + ) + self.add_route( + endpoint="/login", + handler_function=self.get_login, + response_model=GetLoginResponse, + methods=["GET"], + limited=True, + authentication_required=True, + ) + + async def get_health(self, request: Request) -> GetHealthResponse: + """Get server health. + + :param Request request: The incoming HTTP request + :return GetHealthResponse: Health status response + :raise HTTPException: If the server token is not configured + """ + return GetHealthResponse(message="Server is healthy") + + async def get_login(self, request: Request) -> GetLoginResponse: + """Handle user login and return a success response. + + :param Request request: The incoming HTTP request + :return GetLoginResponse: Login success response + :raise HTTPException: If the server token is not configured + """ + return GetLoginResponse(message="Login successful.") diff --git a/python_template_server/template_server.py b/python_template_server/template_server.py index c1c2666..6731f25 100644 --- a/python_template_server/template_server.py +++ b/python_template_server/template_server.py @@ -5,7 +5,7 @@ import os import sys from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from importlib.metadata import metadata from pathlib import Path @@ -13,12 +13,11 @@ import dotenv import uvicorn -from fastapi import FastAPI, HTTPException, Request, Security +from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, Response from fastapi.security import APIKeyHeader from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel from pydantic_core import ValidationError from slowapi import Limiter from slowapi.errors import RateLimitExceeded @@ -26,7 +25,6 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from template_python.logging_setup import add_file_handler, setup_default_logging -from python_template_server.authentication_handler import verify_token from python_template_server.certificate_handler import CertificateHandler from python_template_server.constants import ( API_KEY_HEADER_NAME, @@ -38,15 +36,15 @@ LOGGING_MAX_BYTES_MB, MB_TO_BYTES, STATIC_DIR, + TOKEN_ENV_VAR_NAME, ) from python_template_server.middleware import RequestLoggingMiddleware, SecurityHeadersMiddleware from python_template_server.models import ( CustomJSONResponse, - GetHealthResponse, - GetLoginResponse, ResponseCode, TemplateServerConfig, ) +from python_template_server.routers import BaseRouter, TemplateServerRouter dotenv.load_dotenv(ENV_FILE_PATH) setup_default_logging() @@ -58,13 +56,16 @@ logger = logging.getLogger(__name__) +TEMPLATE_SERVER_ROUTER = TemplateServerRouter(prefix="") + + class TemplateServer(ABC): """Template FastAPI server. This class provides a template for building FastAPI servers with common features such as request logging, security headers and rate limiting. - Ensure you implement the `setup_routes` and `validate_config` methods in subclasses. + Ensure you implement the `routers` property and `validate_config` method in subclasses. """ def __init__( @@ -107,7 +108,7 @@ def __init__( self.host = os.getenv("HOST", "localhost") self.port = int(os.getenv("PORT", "443")) - if not (hashed_token := os.getenv("API_TOKEN_HASH")): + if not (hashed_token := os.getenv(TOKEN_ENV_VAR_NAME)): error_msg = "Server token is not configured. Set the token using: uv run generate-new-token" logger.error(error_msg) raise HTTPException( @@ -123,6 +124,12 @@ def __init__( self._setup_routes() logger.info("Template server initialization complete.") + @staticmethod + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[None]: + """Handle application lifespan events.""" + yield + @property def static_dir_exists(self) -> bool: """Check if the static directory exists. @@ -131,33 +138,25 @@ def static_dir_exists(self) -> bool: """ return self.static_dir.exists() and (self.static_dir / "index.html").exists() - @staticmethod - @asynccontextmanager - async def lifespan(app: FastAPI) -> AsyncGenerator[None]: - """Handle application lifespan events.""" - yield + @property + @abstractmethod + def routers(self) -> list[BaseRouter]: + """List of BaseRouter instances to include in the server. + + :return list[BaseRouter]: List of BaseRouter instances + """ + return [] @abstractmethod def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig: """Validate configuration data against the TemplateServerConfig model. - This method must be implemented by subclasses to validate the configuration data and return a class which - inherits from TemplateServerConfig. - :param dict config_data: The configuration data to validate :return TemplateServerConfig: The validated configuration model :raise ValidationError: If the configuration data is invalid """ return TemplateServerConfig.model_validate(config_data) - @abstractmethod - def setup_routes(self) -> None: - """Add custom API routes. - - This method must be implemented by subclasses to define API endpoints using `add_route`. - """ - pass - def load_config(self, config_filepath: Path) -> TemplateServerConfig: """Load configuration from the specified json file. @@ -185,36 +184,6 @@ def load_config(self, config_filepath: Path) -> TemplateServerConfig: else: return config - async def _verify_api_key( - self, api_key: str | None = Security(APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)) - ) -> None: - """Verify the API key from the request header. - - :param str | None api_key: The API key from the X-API-Key header - :raise HTTPException: If the API key is missing or invalid - """ - if api_key is None: - logger.warning("Missing API key in request!") - raise HTTPException( - status_code=ResponseCode.UNAUTHORIZED, - detail="Missing API key", - ) - - try: - if not verify_token(api_key, self.hashed_token): - logger.warning("Invalid API key attempt!") - raise HTTPException( - status_code=ResponseCode.UNAUTHORIZED, - detail="Invalid API key", - ) - logger.debug("API key validated successfully.") - except ValueError as e: - logger.exception("Error verifying API key!") - raise HTTPException( - status_code=ResponseCode.UNAUTHORIZED, - detail=str(e), - ) from e - def _setup_request_logging(self) -> None: """Set up request logging middleware.""" self.app.add_middleware(RequestLoggingMiddleware) @@ -293,15 +262,25 @@ def _setup_rate_limiting(self) -> None: self.config.rate_limit.storage_uri or "in-memory", ) - def _limit_route(self, route_function: Callable[..., Any]) -> Callable[..., Any]: - """Apply rate limiting to a route function if enabled. + async def _custom_404_handler(self, request: Request, exc: StarletteHTTPException) -> Response: + """Handle 404 errors by serving custom 404.html if available.""" + if exc.status_code == ResponseCode.NOT_FOUND and self.static_dir_exists: + not_found_page = self.static_dir / "404.html" + if not_found_page.is_file(): + return FileResponse(not_found_page, status_code=ResponseCode.NOT_FOUND) + raise exc - :param Callable route_function: The route handler function - :return Callable: The potentially rate-limited route handler - """ - if self.limiter is not None: - return self.limiter.limit(self.config.rate_limit.rate_limit)(route_function) # type: ignore[no-any-return] - return route_function + def _setup_routes(self) -> None: + """Set up API routes.""" + for router in [TEMPLATE_SERVER_ROUTER, *self.routers]: + router.configure(self.hashed_token, self.limiter, self.config.rate_limit.rate_limit) + router.setup_routes() + self.app.include_router(router.router) + + if self.static_dir_exists: + logger.info("Mounting static directory: %s", self.static_dir) + self.app.mount("/", StaticFiles(directory=str(self.static_dir), html=True), name="static") + self.app.add_exception_handler(StarletteHTTPException, self._custom_404_handler) # type: ignore[arg-type] def run(self) -> None: """Run the server using uvicorn.""" @@ -327,81 +306,3 @@ def run(self) -> None: except Exception: logger.exception("Failed to start!") sys.exit(1) - - def add_route( - self, - endpoint: str, - handler_function: Callable, - response_model: type[BaseModel], - methods: list[str], - limited: bool = True, # noqa: FBT001, FBT002 - authentication_required: bool = True, # noqa: FBT001, FBT002 - ) -> None: - """Add an API route. - - :param str endpoint: The API endpoint path - :param Callable handler_function: The handler function for the endpoint - :param BaseModel response_model: The Pydantic model for the response - :param list[str] methods: The HTTP methods for the endpoint - :param bool limited: Whether to apply rate limiting to this route - :param bool authentication_required: Whether authentication is required for this route - """ - self.app.add_api_route( - path=endpoint, - endpoint=self._limit_route(handler_function) if limited else handler_function, - methods=methods, - response_model=response_model, - dependencies=[Security(self._verify_api_key)] if authentication_required else None, - ) - - def _setup_routes(self) -> None: - """Set up API routes.""" - self.add_route( - endpoint="/health", - handler_function=self.get_health, - response_model=GetHealthResponse, - methods=["GET"], - limited=False, - authentication_required=False, - ) - self.add_route( - endpoint="/login", - handler_function=self.get_login, - response_model=GetLoginResponse, - methods=["GET"], - limited=True, - authentication_required=True, - ) - self.setup_routes() - if self.static_dir_exists: - logger.info("Mounting static directory: %s", self.static_dir) - self.app.mount("/", StaticFiles(directory=str(self.static_dir), html=True), name="static") - - @self.app.exception_handler(StarletteHTTPException) - async def custom_404_handler(request: Request, exc: StarletteHTTPException) -> FileResponse: - """Handle 404 errors by serving custom 404.html if available.""" - if exc.status_code == ResponseCode.NOT_FOUND and self.static_dir_exists: - not_found_page = self.static_dir / "404.html" - if not_found_page.is_file(): - return FileResponse(not_found_page, status_code=ResponseCode.NOT_FOUND) - raise exc - - async def get_health(self, request: Request) -> GetHealthResponse: - """Get server health. - - :param Request request: The incoming HTTP request - :return GetHealthResponse: Health status response - :raise HTTPException: If the server token is not configured - """ - return GetHealthResponse(message="Server is healthy") - - async def get_login(self, request: Request) -> GetLoginResponse: - """Handle user login and return a success response. - - :param Request request: The incoming HTTP request - :return GetLoginResponse: Login success response - :raise HTTPException: If the server token is not configured - """ - msg = "Login successful." - logger.info(msg) - return GetLoginResponse(message=msg) diff --git a/tests/conftest.py b/tests/conftest.py index 28eac96..75ccb67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest +from slowapi import Limiter from python_template_server.models import ( CertificateConfigModel, @@ -15,6 +16,8 @@ SecurityConfigModel, TemplateServerConfig, ) +from python_template_server.routers.template_server_router import TemplateServerRouter +from python_template_server.template_server import TEMPLATE_SERVER_ROUTER # General fixtures @@ -187,3 +190,18 @@ def mock_template_server_config( certificate=mock_certificate_config, json_response=mock_json_response_config, ) + + +# Server fixtures +@pytest.fixture(autouse=True) +def mock_template_server_router() -> TemplateServerRouter: + """Provide a TemplateServerRouter instance for testing.""" + mock_limiter = MagicMock(spec=Limiter) + mock_limiter.limit.return_value = MagicMock(return_value=MagicMock()) + TEMPLATE_SERVER_ROUTER.configure( + hashed_token="hashed_value", # noqa: S106 + limiter=mock_limiter, + rate_limit="10/minute", + ) + TEMPLATE_SERVER_ROUTER.setup_routes() + return TEMPLATE_SERVER_ROUTER diff --git a/tests/routers/test_base_router.py b/tests/routers/test_base_router.py new file mode 100644 index 0000000..f82d23a --- /dev/null +++ b/tests/routers/test_base_router.py @@ -0,0 +1,194 @@ +"""Unit tests for the python_template_server.routers.base_router module.""" + +import asyncio +from collections.abc import Generator +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException, Request +from fastapi.routing import APIRoute +from slowapi import Limiter + +from python_template_server.models import BaseResponse, ResponseCode +from python_template_server.routers import BaseRouter + +MOCK_TOKEN = "hashed_value" # noqa: S105 +MOCK_RATE_LIMIT = "10/minute" + + +class MockRouter(BaseRouter): + """Mock implementation of BaseRouter for testing.""" + + def mock_unprotected_method(self, request: Request) -> BaseResponse: + """Mock unprotected method.""" + return BaseResponse(message="unprotected endpoint") + + def mock_protected_method(self, request: Request) -> BaseResponse: + """Mock protected method.""" + return BaseResponse(message="protected endpoint") + + def mock_unlimited_unprotected_method(self, request: Request) -> BaseResponse: + """Mock unlimited unprotected method.""" + return BaseResponse(message="unlimited unprotected endpoint") + + def mock_unlimited_protected_method(self, request: Request) -> BaseResponse: + """Mock unlimited protected method.""" + return BaseResponse(message="unlimited protected endpoint") + + def setup_routes(self) -> None: + """Set up mock routes for testing.""" + mock_limiter = MagicMock(spec=Limiter) + mock_limiter.limit.return_value = MagicMock(return_value=MagicMock()) + + self.configure(hashed_token=MOCK_TOKEN, limiter=mock_limiter, rate_limit=MOCK_RATE_LIMIT) + self.add_route( + endpoint="/unauthenticated-endpoint", + handler_function=self.mock_unprotected_method, + response_model=BaseResponse, + methods=["GET"], + limited=True, + authentication_required=False, + ) + self.add_route( + endpoint="/authenticated-endpoint", + handler_function=self.mock_protected_method, + response_model=BaseResponse, + methods=["POST"], + limited=True, + authentication_required=True, + ) + self.add_route( + endpoint="/unlimited-unauthenticated-endpoint", + handler_function=self.mock_unlimited_unprotected_method, + response_model=BaseResponse, + methods=["GET"], + limited=False, + authentication_required=False, + ) + self.add_route( + endpoint="/unlimited-authenticated-endpoint", + handler_function=self.mock_unlimited_protected_method, + response_model=BaseResponse, + methods=["POST"], + limited=False, + authentication_required=True, + ) + + +@pytest.fixture +def mock_router() -> MockRouter: + """Fixture to create a mock router instance.""" + router = MockRouter(prefix="/test") + router.setup_routes() + return router + + +@pytest.fixture +def mock_verify_token() -> Generator[MagicMock]: + """Mock the verify_token function.""" + with patch("python_template_server.routers.base_router.verify_token") as mock_verify: + yield mock_verify + + +class TestBaseRouterInitialization: + """Unit tests for BaseRouter initialization.""" + + def test_base_router_initialization(self, mock_router: MockRouter) -> None: + """Test that the BaseRouter initializes with the correct prefix and default values.""" + assert mock_router.router.prefix == "/test" + assert mock_router.hashed_token == MOCK_TOKEN + + +class TestVerifyApiKey: + """Unit tests for the _verify_api_key method.""" + + def test_verify_api_key_valid(self, mock_router: BaseRouter, mock_verify_token: MagicMock) -> None: + """Test _verify_api_key with valid API key.""" + mock_verify_token.return_value = True + + result = asyncio.run(mock_router._verify_api_key("valid_key")) + assert result is None + + def test_verify_api_key_missing(self, mock_router: BaseRouter) -> None: + """Test _verify_api_key with missing API key.""" + with pytest.raises(HTTPException) as exc_info: + asyncio.run(mock_router._verify_api_key(None)) + + assert exc_info.value.status_code == ResponseCode.BAD_REQUEST + assert exc_info.value.detail == "Missing API key" + + def test_verify_api_key_invalid(self, mock_router: BaseRouter, mock_verify_token: MagicMock) -> None: + """Test _verify_api_key with invalid API key.""" + mock_verify_token.return_value = False + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(mock_router._verify_api_key("invalid_key")) + + assert exc_info.value.status_code == ResponseCode.UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + + def test_verify_api_key_value_error(self, mock_router: BaseRouter, mock_verify_token: MagicMock) -> None: + """Test _verify_api_key when verify_token raises ValueError.""" + mock_verify_token.side_effect = ValueError("No stored token hash found") + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(mock_router._verify_api_key("some_key")) + + assert exc_info.value.status_code == ResponseCode.INTERNAL_SERVER_ERROR + assert "No stored token hash found" in exc_info.value.detail + + +class TestConfigure: + """Unit tests for the configure method.""" + + def test_configure(self, mock_router: BaseRouter) -> None: + """Test that configure sets the hashed_token, limiter, and rate_limit correctly.""" + assert mock_router.hashed_token == MOCK_TOKEN + assert isinstance(mock_router.limiter, Limiter) + assert mock_router.rate_limit == MOCK_RATE_LIMIT + + +class TestAddRoutes: + """Integration tests for the routes in the mock router.""" + + def test_add_unauthenticated_route(self, mock_router: BaseRouter) -> None: + """Test add_route with authentication disabled adds routes without authentication.""" + api_routes = [route for route in mock_router.router.routes if isinstance(route, APIRoute)] + routes = [route.path for route in api_routes] + assert "/test/unauthenticated-endpoint" in routes + + # Find the specific route and verify it has no dependencies (unauthenticated) + test_route = next((route for route in api_routes if route.path == "/test/unauthenticated-endpoint"), None) + assert test_route is not None + + # Verify the route has no dependencies (unauthenticated) + assert len(test_route.dependencies) == 0 + + # Verify method and response model + assert "GET" in test_route.methods + assert test_route.response_model == BaseResponse + + def test_add_authenticated_route(self, mock_router: BaseRouter) -> None: + """Test add_route with authentication enabled adds routes with authentication.""" + api_routes = [route for route in mock_router.router.routes if isinstance(route, APIRoute)] + routes = [route.path for route in api_routes] + assert "/test/authenticated-endpoint" in routes + + # Find the specific route + test_route = next((route for route in api_routes if route.path == "/test/authenticated-endpoint"), None) + assert test_route is not None + + # Verify the route has dependencies (authentication) + assert len(test_route.dependencies) > 0 + dependency = test_route.dependencies[0] + assert dependency.dependency == mock_router._verify_api_key + + # Verify method and response model + assert "POST" in test_route.methods + assert test_route.response_model == BaseResponse + + def test_limited_parameter_with_rate_limiting_enabled(self, mock_router: BaseRouter) -> None: + """Test that limited=True applies rate limiting when limiter is enabled.""" + assert isinstance(mock_router.limiter, Limiter) + assert mock_router.limiter.limit.call_count == 2 # type: ignore[attr-defined] # noqa: PLR2004 + mock_router.limiter.limit.assert_any_call(MOCK_RATE_LIMIT) # type: ignore[attr-defined] diff --git a/tests/routers/test_template_server_router.py b/tests/routers/test_template_server_router.py new file mode 100644 index 0000000..1074789 --- /dev/null +++ b/tests/routers/test_template_server_router.py @@ -0,0 +1,55 @@ +"""Unit tests for the python_template_server.routers.template_server_router module.""" + +import asyncio +from unittest.mock import MagicMock + +import pytest +from fastapi import Request +from fastapi.routing import APIRoute + +from python_template_server.routers import TemplateServerRouter + + +class TestRoutes: + """Integration tests for the mock routes in ExampleServer.""" + + def test_setup_routes(self, mock_template_server_router: TemplateServerRouter) -> None: + """Test that routes are set up correctly.""" + api_routes = [route for route in mock_template_server_router.router.routes if isinstance(route, APIRoute)] + routes = [route.path for route in api_routes] + expected_endpoints = [ + "/health", + "/login", + ] + for endpoint in expected_endpoints: + assert endpoint in routes + + +class TestGetHealthEndpoint: + """Integration tests for the /health endpoint.""" + + @pytest.fixture + def mock_request_object(self) -> Request: + """Provide a mock Request object.""" + return MagicMock(spec=Request) + + def test_get_health(self, mock_template_server_router: TemplateServerRouter, mock_request_object: Request) -> None: + """Test the /health endpoint method.""" + response = asyncio.run(mock_template_server_router.get_health(mock_request_object)) + assert response.message == "Server is healthy" + assert isinstance(response.timestamp, str) + + +class TestGetLoginEndpoint: + """Integration tests for the /login endpoint.""" + + @pytest.fixture + def mock_request_object(self) -> Request: + """Provide a mock Request object.""" + return MagicMock(spec=Request) + + def test_get_login(self, mock_template_server_router: TemplateServerRouter, mock_request_object: Request) -> None: + """Test the /login endpoint method.""" + response = asyncio.run(mock_template_server_router.get_login(mock_request_object)) + assert response.message == "Login successful." + assert isinstance(response.timestamp, str) diff --git a/tests/test_template_server.py b/tests/test_template_server.py index c6e8d58..ee27b02 100644 --- a/tests/test_template_server.py +++ b/tests/test_template_server.py @@ -8,22 +8,20 @@ from collections.abc import Generator from importlib.metadata import PackageMetadata from pathlib import Path -from typing import Any from unittest.mock import MagicMock, patch import pytest from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.routing import APIRoute from fastapi.security import APIKeyHeader from fastapi.testclient import TestClient from slowapi.errors import RateLimitExceeded from starlette.status import HTTP_429_TOO_MANY_REQUESTS from python_template_server.constants import API_PREFIX +from python_template_server.main import ExampleServer from python_template_server.middleware import RequestLoggingMiddleware, SecurityHeadersMiddleware from python_template_server.models import ( - BaseResponse, CustomJSONResponse, ResponseCode, TemplateServerConfig, @@ -46,13 +44,6 @@ def mock_package_metadata() -> Generator[PackageMetadata]: yield mock_metadata -@pytest.fixture -def mock_verify_token() -> Generator[MagicMock]: - """Mock the verify_token function.""" - with patch("python_template_server.template_server.verify_token") as mock_verify: - yield mock_verify - - MOCK_INDEX_CONTENT = "