Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions python_template_server/template_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ def __init__(
)
self.hashed_token = hashed_token

logger.info("Setting up server features...")
self._setup_request_logging()
self._setup_security_headers()
self._setup_cors()
self._setup_rate_limiting()
self._setup_routes()
logger.info("Template server initialization complete.")
logger.info("Template server initialization complete!")

@staticmethod
@asynccontextmanager
Expand Down Expand Up @@ -169,6 +170,7 @@ def load_config(self, config_filepath: Path) -> TemplateServerConfig:
sys.exit(1)

try:
logger.info("Loading configuration from: %s", config_filepath)
config_data = json.loads(config_filepath.read_text(encoding="utf-8"))
config = self.validate_config(config_data)
config.save_to_file(config_filepath)
Expand All @@ -187,7 +189,7 @@ def load_config(self, config_filepath: Path) -> TemplateServerConfig:
def _setup_request_logging(self) -> None:
"""Set up request logging middleware."""
self.app.add_middleware(RequestLoggingMiddleware)
logger.info("Request logging enabled")
logger.info("Request logging: ENABLED")

def _setup_security_headers(self) -> None:
"""Set up security headers middleware."""
Expand All @@ -198,15 +200,15 @@ def _setup_security_headers(self) -> None:
)

logger.info(
"Security headers enabled: HSTS max-age=%s, CSP=%s",
"Security headers: ENABLED | HSTS max-age=%s, CSP=%s",
self.config.security.hsts_max_age,
self.config.security.content_security_policy,
)

def _setup_cors(self) -> None:
"""Set up CORS middleware."""
if not self.config.cors.enabled:
logger.info("CORS is disabled")
logger.info("CORS: DISABLED")
return

self.app.add_middleware(
Expand All @@ -220,7 +222,7 @@ def _setup_cors(self) -> None:
)

logger.info(
"CORS enabled: origins=%s, credentials=%s, methods=%s, headers=%s",
"CORS: ENABLED | origins=%s, credentials=%s, methods=%s, headers=%s",
self.config.cors.allow_origins,
self.config.cors.allow_credentials,
self.config.cors.allow_methods,
Expand All @@ -234,7 +236,7 @@ async def _rate_limit_exception_handler(self, request: Request, exc: RateLimitEx
:param RateLimitExceeded exc: The rate limit exceeded exception
:return JSONResponse: HTTP 429 JSON response
"""
logger.warning("Rate limit exceeded for %s", request.url.path)
logger.warning("Rate limit exceeded for: %s", request.url.path)
return CustomJSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
Expand All @@ -244,7 +246,7 @@ async def _rate_limit_exception_handler(self, request: Request, exc: RateLimitEx
def _setup_rate_limiting(self) -> None:
"""Set up rate limiting middleware."""
if not self.config.rate_limit.enabled:
logger.info("Rate limiting is disabled")
logger.info("Rate limiting: DISABLED")
self.limiter = None
return

Expand All @@ -257,29 +259,29 @@ def _setup_rate_limiting(self) -> None:
self.app.add_exception_handler(RateLimitExceeded, self._rate_limit_exception_handler) # type: ignore[arg-type]

logger.info(
"Rate limiting enabled: rate=%s, storage=%s",
"Rate limiting: ENABLED | rate=%s, storage=%s",
self.config.rate_limit.rate_limit,
self.config.rate_limit.storage_uri or "in-memory",
)

async def _custom_404_handler(self, request: Request, exc: StarletteHTTPException) -> Response:
"""Handle 404 errors by serving custom 404.html if available."""
if exc.status_code == ResponseCode.NOT_FOUND and self.static_dir_exists:
not_found_page = self.static_dir / "404.html"
if not_found_page.is_file():
if (not_found_page := self.static_dir / "404.html").is_file():
return FileResponse(not_found_page, status_code=ResponseCode.NOT_FOUND)
raise exc

def _setup_routes(self) -> None:
"""Set up API routes."""
for router in [TEMPLATE_SERVER_ROUTER, *self.routers]:
routers: list[BaseRouter] = [TEMPLATE_SERVER_ROUTER, *self.routers]
for router in routers:
router.configure(self.hashed_token, self.limiter, self.config.rate_limit.rate_limit)
router.setup_routes()
self.app.include_router(router.router)

if self.static_dir_exists:
logger.info("Mounting static directory: %s", self.static_dir)
self.app.mount("/", StaticFiles(directory=str(self.static_dir), html=True), name="static")
self.app.mount("/", StaticFiles(directory=self.static_dir, html=True), name="static")
self.app.add_exception_handler(StarletteHTTPException, self._custom_404_handler) # type: ignore[arg-type]

def run(self) -> None:
Expand Down
1 change: 0 additions & 1 deletion tests/test_template_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def mock_template_server(
"""Provide a ExampleServer instance for testing."""
with (
patch("python_template_server.template_server.CertificateHandler", return_value=MagicMock(), autospec=True),
patch("python_template_server.template_server.TemplateServer.static_dir_exists", return_value=True),
):
mock_tmp_static_path.mkdir(parents=True, exist_ok=True)
(mock_tmp_static_path / "index.html").write_text(MOCK_INDEX_CONTENT)
Expand Down