From c1d00521b591d9e0b438db46bc060138c8a0fe70 Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 11:34:54 -0600 Subject: [PATCH 01/10] feat: auth-always default, remove CLOUDPAM_AUTH_ENABLED toggle - Remove RegisterRoutes() (unprotected variant) - All routes now use RegisterProtectedRoutes() with RBAC - Add missing import routes to protected registration - BREAKING: CLOUDPAM_AUTH_ENABLED env var removed - First boot always shows setup wizard - Hardcode auth_enabled=true in /healthz response Co-Authored-By: Claude Opus 4.6 --- cmd/cloudpam/main.go | 42 +++++++------------- internal/api/analysis_handlers_test.go | 2 +- internal/api/auth_handlers_test.go | 9 +++-- internal/api/handlers_test.go | 8 ++-- internal/api/integration_test.go | 31 ++++++++++----- internal/api/recommendation_handlers_test.go | 2 +- internal/api/search_handlers_test.go | 2 +- internal/api/server.go | 40 +++---------------- internal/api/system_handlers.go | 6 +-- internal/api/test_helpers_test.go | 35 ++++++++++++++++ internal/testutil/testutil.go | 36 +++++++++++------ 11 files changed, 114 insertions(+), 99 deletions(-) create mode 100644 internal/api/test_helpers_test.go diff --git a/cmd/cloudpam/main.go b/cmd/cloudpam/main.go index a8e72de..4780c73 100644 --- a/cmd/cloudpam/main.go +++ b/cmd/cloudpam/main.go @@ -170,34 +170,22 @@ func main() { aiSrv := api.NewAIPlanningServer(srv, aiService, convStore) logger.Info("ai planning subsystem initialized") - // When CLOUDPAM_AUTH_ENABLED is set (or fresh install needs setup), use protected routes with RBAC. - // Otherwise use unprotected routes for development. - authEnabled := os.Getenv("CLOUDPAM_AUTH_ENABLED") == "true" || os.Getenv("CLOUDPAM_AUTH_ENABLED") == "1" - needsSetup := len(existingUsers) == 0 - if authEnabled || needsSetup { - srv.RegisterProtectedRoutes(keyStore, sessionStore, userStore, logger.Slog()) - authSrv := api.NewAuthServerWithStores(srv, keyStore, sessionStore, userStore, auditLogger) - authSrv.RegisterProtectedAuthRoutes(logger.Slog()) - userSrv := api.NewUserServer(srv, keyStore, userStore, sessionStore, auditLogger) - userSrv.RegisterProtectedUserRoutes(logger.Slog()) - dualMW := api.DualAuthMiddleware(keyStore, sessionStore, userStore, true, logger.Slog()) - discoverySrv.RegisterProtectedDiscoveryRoutes(dualMW, logger.Slog()) - analysisSrv.RegisterProtectedAnalysisRoutes(dualMW, logger.Slog()) - recSrv.RegisterProtectedRecommendationRoutes(dualMW, logger.Slog()) - aiSrv.RegisterProtectedAIPlanningRoutes(dualMW, logger.Slog()) - logger.Info("authentication enabled (RBAC enforced)") + // Auth is always enabled — register protected routes with RBAC. + srv.RegisterProtectedRoutes(keyStore, sessionStore, userStore, logger.Slog()) + authSrv := api.NewAuthServerWithStores(srv, keyStore, sessionStore, userStore, auditLogger) + authSrv.RegisterProtectedAuthRoutes(logger.Slog()) + userSrv := api.NewUserServer(srv, keyStore, userStore, sessionStore, auditLogger) + userSrv.RegisterProtectedUserRoutes(logger.Slog()) + dualMW := api.DualAuthMiddleware(keyStore, sessionStore, userStore, true, logger.Slog()) + discoverySrv.RegisterProtectedDiscoveryRoutes(dualMW, logger.Slog()) + analysisSrv.RegisterProtectedAnalysisRoutes(dualMW, logger.Slog()) + recSrv.RegisterProtectedRecommendationRoutes(dualMW, logger.Slog()) + aiSrv.RegisterProtectedAIPlanningRoutes(dualMW, logger.Slog()) + + if len(existingUsers) == 0 { + logger.Info("first-boot setup required", "hint", "visit the UI to create an admin account") } else { - srv.RegisterRoutes() - authSrv := api.NewAuthServerWithStores(srv, keyStore, sessionStore, userStore, auditLogger) - authSrv.RegisterAuthRoutes() - userSrv := api.NewUserServer(srv, keyStore, userStore, sessionStore, auditLogger) - userSrv.RegisterUserRoutes() - discoverySrv.RegisterDiscoveryRoutes() - analysisSrv.RegisterAnalysisRoutes() - recSrv.RegisterRecommendationRoutes() - aiSrv.RegisterAIPlanningRoutes() - logger.Info("authentication disabled (all routes open)", - "hint", "set CLOUDPAM_AUTH_ENABLED=true to enable RBAC") + logger.Info("authentication enforced", "users", len(existingUsers)) } // Background session cleanup every 15 minutes. diff --git a/internal/api/analysis_handlers_test.go b/internal/api/analysis_handlers_test.go index c068ccb..bc77265 100644 --- a/internal/api/analysis_handlers_test.go +++ b/internal/api/analysis_handlers_test.go @@ -24,7 +24,7 @@ func setupAnalysisServer() (*stdhttp.ServeMux, *storage.MemoryStore) { Output: io.Discard, }) srv := NewServer(mux, st, logger, nil, nil) - srv.RegisterRoutes() + srv.registerUnprotectedTestRoutes() analysisSvc := planning.NewAnalysisService(st) analysisSrv := NewAnalysisServer(srv, analysisSvc) analysisSrv.RegisterAnalysisRoutes() diff --git a/internal/api/auth_handlers_test.go b/internal/api/auth_handlers_test.go index 1232220..217d65a 100644 --- a/internal/api/auth_handlers_test.go +++ b/internal/api/auth_handlers_test.go @@ -29,9 +29,12 @@ func setupAuthTestServer() (*AuthServer, *auth.MemoryKeyStore, *audit.MemoryAudi keyStore := auth.NewMemoryKeyStore() - authSrv := NewAuthServer(srv, keyStore, auditLogger) - srv.RegisterRoutes() - authSrv.RegisterAuthRoutes() + sessionStore := auth.NewMemorySessionStore() + userStore := auth.NewMemoryUserStore() + + authSrv := NewAuthServerWithStores(srv, keyStore, sessionStore, userStore, auditLogger) + srv.registerUnprotectedTestRoutes() + authSrv.registerUnprotectedAuthTestRoutes() return authSrv, keyStore, auditLogger } diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 083c26f..36163de 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -33,7 +33,7 @@ func setupTestServer() (*Server, *storage.MemoryStore) { Output: io.Discard, }) srv := NewServer(mux, st, logger, nil, nil) - srv.RegisterRoutes() + srv.registerUnprotectedTestRoutes() return srv, st } @@ -690,7 +690,7 @@ func TestReadyzEndpointDatabaseFailure(t *testing.T) { Output: io.Discard, }) srv := NewServer(mux, st, logger, nil, nil) - srv.RegisterRoutes() + srv.registerUnprotectedTestRoutes() req := httptest.NewRequest(stdhttp.MethodGet, "/readyz", nil) rr := httptest.NewRecorder() @@ -729,7 +729,7 @@ func TestMetricsEndpoint(t *testing.T) { Version: "test", }) srv := NewServer(mux, st, logger, metrics, nil) - srv.RegisterRoutes() + srv.registerUnprotectedTestRoutes() // Make a request to trigger metrics recording doJSON(t, srv.mux, stdhttp.MethodPost, "/api/v1/pools", `{"name":"root","cidr":"10.0.0.0/16"}`, stdhttp.StatusCreated) @@ -776,7 +776,7 @@ func TestMetricsEndpointDisabled(t *testing.T) { }) // Create without metrics (nil) srv := NewServer(mux, st, logger, nil, nil) - srv.RegisterRoutes() + srv.registerUnprotectedTestRoutes() // When metrics is nil, /metrics falls through to the SPA catch-all // which returns 200 with index.html (SPA handles client-side routing). diff --git a/internal/api/integration_test.go b/internal/api/integration_test.go index 4ddbc4c..16d729f 100644 --- a/internal/api/integration_test.go +++ b/internal/api/integration_test.go @@ -205,6 +205,9 @@ func TestIntegration_MetricsEndpoint(t *testing.T) { components := testutil.NewTestServer(t, cfg) defer components.Cleanup() + // Create an admin API key for authenticated requests + adminKey, _ := testutil.CreateTestAPIKey(t, components.KeyStore, "metrics-admin", []string{"*"}) + client := components.HTTPClient() baseURL := components.Server.URL @@ -214,12 +217,13 @@ func TestIntegration_MetricsEndpoint(t *testing.T) { path string body string status int + auth bool }{ - {http.MethodGet, "/healthz", "", http.StatusOK}, - {http.MethodGet, "/readyz", "", http.StatusOK}, - {http.MethodPost, "/api/v1/pools", `{"name":"metrics-pool","cidr":"10.0.0.0/16"}`, http.StatusCreated}, - {http.MethodGet, "/api/v1/pools", "", http.StatusOK}, - {http.MethodPost, "/api/v1/pools", `invalid json`, http.StatusBadRequest}, + {http.MethodGet, "/healthz", "", http.StatusOK, false}, + {http.MethodGet, "/readyz", "", http.StatusOK, false}, + {http.MethodPost, "/api/v1/pools", `{"name":"metrics-pool","cidr":"10.0.0.0/16"}`, http.StatusCreated, true}, + {http.MethodGet, "/api/v1/pools", "", http.StatusOK, true}, + {http.MethodPost, "/api/v1/pools", `invalid json`, http.StatusBadRequest, true}, } for _, r := range requests { @@ -231,6 +235,9 @@ func TestIntegration_MetricsEndpoint(t *testing.T) { if r.body != "" { req.Header.Set("Content-Type", "application/json") } + if r.auth { + req.Header.Set("Authorization", "Bearer "+adminKey) + } resp, err := client.Do(req) if err != nil { t.Fatalf("request %s %s failed: %v", r.method, r.path, err) @@ -720,8 +727,8 @@ func TestIntegration_ErrorResponses(t *testing.T) { client := components.HTTPClient() baseURL := components.Server.URL - // Create a valid API key for some tests - plaintext, _ := testutil.CreateTestAPIKey(t, components.KeyStore, "Error Test Key", []string{"pools:read"}) + // Create a valid API key for some tests (needs read+write for POST tests) + plaintext, _ := testutil.CreateTestAPIKey(t, components.KeyStore, "Error Test Key", []string{"pools:read", "pools:write"}) testCases := []struct { name string @@ -908,9 +915,11 @@ func TestIntegration_APIKeyExpiration(t *testing.T) { // TestIntegration_ScopeEnforcement tests that API key scopes are properly enforced. func TestIntegration_ScopeEnforcement(t *testing.T) { - // Create a simple test server without the complex middleware for scope testing + // Create a simple test server with protected routes for scope testing store := storage.NewMemoryStore() keyStore := auth.NewMemoryKeyStore() + sessionStore := auth.NewMemorySessionStore() + userStore := auth.NewMemoryUserStore() logger := observability.NewLogger(observability.Config{ Level: "debug", Format: "json", @@ -919,10 +928,10 @@ func TestIntegration_ScopeEnforcement(t *testing.T) { mux := http.NewServeMux() srv := api.NewServer(mux, store, logger, nil, nil) - srv.RegisterRoutes() + srv.RegisterProtectedRoutes(keyStore, sessionStore, userStore, logger.Slog()) - // Wrap with auth middleware (required) - handler := api.AuthMiddleware(keyStore, true, logger.Slog())(mux) + // Protected routes already include per-route DualAuthMiddleware + var handler http.Handler = mux testServer := httptest.NewServer(handler) defer testServer.Close() diff --git a/internal/api/recommendation_handlers_test.go b/internal/api/recommendation_handlers_test.go index f29452e..635f0d0 100644 --- a/internal/api/recommendation_handlers_test.go +++ b/internal/api/recommendation_handlers_test.go @@ -17,7 +17,7 @@ func setupRecommendationServer() (*stdhttp.ServeMux, *storage.MemoryStore) { st := storage.NewMemoryStore() mux := stdhttp.NewServeMux() srv := NewServerWithSlog(mux, st, nil) - srv.RegisterRoutes() + srv.registerUnprotectedTestRoutes() analysisSvc := planning.NewAnalysisService(st) recStore := storage.NewMemoryRecommendationStore(st) diff --git a/internal/api/search_handlers_test.go b/internal/api/search_handlers_test.go index 2ebcd16..cbbd5cb 100644 --- a/internal/api/search_handlers_test.go +++ b/internal/api/search_handlers_test.go @@ -15,7 +15,7 @@ func setupSearchServer(t *testing.T) (*Server, *storage.MemoryStore) { store := storage.NewMemoryStore() mux := http.NewServeMux() srv := NewServer(mux, store, nil, nil, nil) - srv.RegisterRoutes() + srv.registerUnprotectedTestRoutes() return srv, store } diff --git a/internal/api/server.go b/internal/api/server.go index 84caf30..4fb9f36 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -142,41 +142,6 @@ type statusRecorder struct { func (s *statusRecorder) WriteHeader(code int) { s.status = code; s.ResponseWriter.WriteHeader(code) } -// RegisterRoutes registers all HTTP routes without RBAC protection. -// This is for backward compatibility. Use RegisterProtectedRoutes for RBAC enforcement. -func (s *Server) RegisterRoutes() { - // localAuthEnabled stays false — no auth enforcement in dev mode - s.mux.HandleFunc("/openapi.yaml", s.handleOpenAPISpec) - s.mux.HandleFunc("/healthz", s.handleHealth) - s.mux.HandleFunc("/api/v1/auth/setup", s.handleSetup) - s.mux.HandleFunc("/readyz", s.handleReady) - // Metrics endpoint - if s.metrics != nil { - s.mux.Handle("/metrics", s.metrics.Handler()) - } - s.mux.HandleFunc("/api/v1/test-sentry", s.handleTestSentry) - s.mux.HandleFunc("/api/v1/pools", s.handlePools) - // Note: /api/v1/pools/hierarchy is handled by handlePoolsSubroutes - s.mux.HandleFunc("/api/v1/pools/", s.handlePoolsSubroutes) - s.mux.HandleFunc("/api/v1/accounts", s.handleAccounts) - s.mux.HandleFunc("/api/v1/accounts/", s.handleAccountsSubroutes) - s.mux.HandleFunc("/api/v1/blocks", s.handleBlocksList) - // Data export (CSV in ZIP) - s.mux.HandleFunc("/api/v1/export", s.handleExport) - // Data import (CSV) - s.mux.HandleFunc("/api/v1/import/accounts", s.handleImportAccounts) - s.mux.HandleFunc("/api/v1/import/pools", s.handleImportPools) - // Audit logs (unprotected access) - s.mux.HandleFunc("/api/v1/audit", s.handleAuditList) - // Schema planner API - s.mux.HandleFunc("/api/v1/schema/check", s.handleSchemaCheck) - s.mux.HandleFunc("/api/v1/schema/apply", s.handleSchemaApply) - // Search API - s.mux.HandleFunc("/api/v1/search", s.handleSearch) - // Unified React SPA (catch-all) - s.mux.Handle("/", s.handleSPA()) -} - // RegisterProtectedRoutes registers all HTTP routes with RBAC protection. // Routes are protected based on the resource and action being performed. // Public endpoints (health, metrics, static) remain unprotected. @@ -227,6 +192,11 @@ func (s *Server) RegisterProtectedRoutes(keyStore auth.KeyStore, sessionStore au s.mux.Handle("/api/v1/schema/check", dualMW(poolsReadMW(http.HandlerFunc(s.handleSchemaCheck)))) s.mux.Handle("/api/v1/schema/apply", dualMW(poolsCreateMW(http.HandlerFunc(s.handleSchemaApply)))) + // Import endpoints - require create permissions + accountsCreateMW := RequirePermissionMiddleware(auth.ResourceAccounts, auth.ActionCreate, slogger) + s.mux.Handle("POST /api/v1/import/accounts", dualMW(accountsCreateMW(http.HandlerFunc(s.handleImportAccounts)))) + s.mux.Handle("POST /api/v1/import/pools", dualMW(poolsCreateMW(http.HandlerFunc(s.handleImportPools)))) + // Search endpoint - requires pools:read s.mux.Handle("/api/v1/search", dualMW(poolsReadMW(http.HandlerFunc(s.handleSearch)))) } diff --git a/internal/api/system_handlers.go b/internal/api/system_handlers.go index 04186d1..90fbf31 100644 --- a/internal/api/system_handlers.go +++ b/internal/api/system_handlers.go @@ -33,8 +33,8 @@ func (s *Server) handleOpenAPISpec(w http.ResponseWriter, r *http.Request) { func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]any{ "status": "ok", - "auth_enabled": s.authEnabled, - "local_auth_enabled": s.localAuthEnabled, + "auth_enabled": true, + "local_auth_enabled": true, "needs_setup": s.needsSetup, }) } @@ -199,8 +199,6 @@ func (s *Server) handleSetup(w http.ResponseWriter, r *http.Request) { } s.needsSetup = false - s.authEnabled = true - s.localAuthEnabled = true s.logAudit(r.Context(), "create", "user", user.ID, user.Username, http.StatusCreated) diff --git a/internal/api/test_helpers_test.go b/internal/api/test_helpers_test.go new file mode 100644 index 0000000..1e11090 --- /dev/null +++ b/internal/api/test_helpers_test.go @@ -0,0 +1,35 @@ +package api + +// registerUnprotectedTestRoutes registers all core HTTP routes without RBAC protection. +// This is a test-only helper for internal package tests that exercise handlers directly +// without going through authentication middleware. +func (s *Server) registerUnprotectedTestRoutes() { + s.mux.HandleFunc("/openapi.yaml", s.handleOpenAPISpec) + s.mux.HandleFunc("/healthz", s.handleHealth) + s.mux.HandleFunc("/api/v1/auth/setup", s.handleSetup) + s.mux.HandleFunc("/readyz", s.handleReady) + if s.metrics != nil { + s.mux.Handle("/metrics", s.metrics.Handler()) + } + s.mux.HandleFunc("/api/v1/test-sentry", s.handleTestSentry) + s.mux.HandleFunc("/api/v1/pools", s.handlePools) + s.mux.HandleFunc("/api/v1/pools/", s.handlePoolsSubroutes) + s.mux.HandleFunc("/api/v1/accounts", s.handleAccounts) + s.mux.HandleFunc("/api/v1/accounts/", s.handleAccountsSubroutes) + s.mux.HandleFunc("/api/v1/blocks", s.handleBlocksList) + s.mux.HandleFunc("/api/v1/export", s.handleExport) + s.mux.HandleFunc("/api/v1/import/accounts", s.handleImportAccounts) + s.mux.HandleFunc("/api/v1/import/pools", s.handleImportPools) + s.mux.HandleFunc("/api/v1/audit", s.handleAuditList) + s.mux.HandleFunc("/api/v1/schema/check", s.handleSchemaCheck) + s.mux.HandleFunc("/api/v1/schema/apply", s.handleSchemaApply) + s.mux.HandleFunc("/api/v1/search", s.handleSearch) + s.mux.Handle("/", s.handleSPA()) +} + +// registerUnprotectedAuthTestRoutes registers auth API endpoints without RBAC protection. +// This is a test-only helper for internal package tests. +func (as *AuthServer) registerUnprotectedAuthTestRoutes() { + as.mux.HandleFunc("/api/v1/auth/keys", as.handleAPIKeys) + as.mux.HandleFunc("/api/v1/auth/keys/", as.handleAPIKeyByID) +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 4744cf0..270fef3 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -52,6 +52,10 @@ type TestServerComponents struct { Store *storage.MemoryStore // KeyStore is the API key store. KeyStore *auth.MemoryKeyStore + // SessionStore is the session store. + SessionStore auth.SessionStore + // UserStore is the user store. + UserStore auth.UserStore // AuditLogger is the audit logger. AuditLogger audit.AuditLogger // Metrics is the metrics collector. @@ -88,6 +92,8 @@ func NewTestServer(t *testing.T, cfg TestServerConfig) *TestServerComponents { // Create key store keyStore := auth.NewMemoryKeyStore() + sessionStore := auth.NewMemorySessionStore() + userStore := auth.NewMemoryUserStore() // Create audit logger var auditLogger audit.AuditLogger @@ -98,11 +104,15 @@ func NewTestServer(t *testing.T, cfg TestServerConfig) *TestServerComponents { // Create the base server mux := http.NewServeMux() srv := api.NewServer(mux, store, logger, metrics, auditLogger) - srv.RegisterRoutes() + srv.RegisterProtectedRoutes(keyStore, sessionStore, userStore, logger.Slog()) // Create auth server for key management endpoints - authSrv := api.NewAuthServer(srv, keyStore, auditLogger) - authSrv.RegisterAuthRoutes() + authSrv := api.NewAuthServerWithStores(srv, keyStore, sessionStore, userStore, auditLogger) + authSrv.RegisterProtectedAuthRoutes(logger.Slog()) + + // Create user server for user management endpoints + userSrv := api.NewUserServer(srv, keyStore, userStore, sessionStore, auditLogger) + userSrv.RegisterProtectedUserRoutes(logger.Slog()) // Build middleware chain var handler http.Handler = mux @@ -114,9 +124,9 @@ func NewTestServer(t *testing.T, cfg TestServerConfig) *TestServerComponents { handler = api.AuditMiddleware(adapter, logger.Slog())(handler) } - // Apply auth middleware if enabled + // Apply auth middleware if enabled (DualAuthMiddleware for session + API key support) if cfg.EnableAuth { - handler = api.AuthMiddleware(keyStore, cfg.RequireAuth, logger.Slog())(handler) + handler = api.DualAuthMiddleware(keyStore, sessionStore, userStore, cfg.RequireAuth, logger.Slog())(handler) } // Apply rate limiting if enabled @@ -144,13 +154,15 @@ func NewTestServer(t *testing.T, cfg TestServerConfig) *TestServerComponents { } return &TestServerComponents{ - Server: testServer, - Store: store, - KeyStore: keyStore, - AuditLogger: auditLogger, - Metrics: metrics, - Logger: logger, - Cleanup: cleanup, + Server: testServer, + Store: store, + KeyStore: keyStore, + SessionStore: sessionStore, + UserStore: userStore, + AuditLogger: auditLogger, + Metrics: metrics, + Logger: logger, + Cleanup: cleanup, } } From 8f709b201924641d6ae9b8647c6c37e059c747b6 Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 12:26:59 -0600 Subject: [PATCH 02/10] fix: extract audit actor from auth context instead of hardcoding anonymous Co-Authored-By: Claude Opus 4.6 --- internal/api/server.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/internal/api/server.go b/internal/api/server.go index 4fb9f36..e604968 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -117,16 +117,27 @@ func (s *Server) logAudit(ctx context.Context, action, resourceType, resourceID, if s.auditLogger == nil { return } + + actor := "anonymous" + actorType := audit.ActorTypeAnonymous + + if user := auth.UserFromContext(ctx); user != nil { + actor = user.Username + actorType = audit.ActorTypeUser + } else if key := auth.APIKeyFromContext(ctx); key != nil { + actor = key.Name + actorType = audit.ActorTypeAPIKey + } + event := &audit.AuditEvent{ - Actor: "anonymous", // Will be overwritten if auth is enabled - ActorType: audit.ActorTypeAnonymous, + Actor: actor, + ActorType: actorType, Action: action, ResourceType: resourceType, ResourceID: resourceID, ResourceName: resourceName, StatusCode: statusCode, } - // Try to get request ID from context if reqID := ctx.Value(requestIDContextKey); reqID != nil { if id, ok := reqID.(string); ok { event.RequestID = id From b22f99ed183e28e96ee49fcd4d5e708f88407b25 Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 12:27:08 -0600 Subject: [PATCH 03/10] =?UTF-8?q?fix:=20prevent=20API=20key=20scope=20elev?= =?UTF-8?q?ation=20=E2=80=94=20callers=20cannot=20grant=20higher=20privile?= =?UTF-8?q?ges=20than=20their=20own=20role?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 --- internal/api/auth_handlers.go | 12 +++++++++ internal/api/auth_handlers_test.go | 43 ++++++++++++++++++++++++++++++ internal/auth/rbac.go | 17 ++++++++++++ 3 files changed, 72 insertions(+) diff --git a/internal/api/auth_handlers.go b/internal/api/auth_handlers.go index 83068e8..e60e3cb 100644 --- a/internal/api/auth_handlers.go +++ b/internal/api/auth_handlers.go @@ -198,6 +198,18 @@ func (as *AuthServer) createAPIKey(w http.ResponseWriter, r *http.Request) { } } + // Prevent scope elevation: callers cannot create keys with higher privileges than their own role. + // Only enforced when the caller is authenticated (callerRole != RoleNone). + callerRole := auth.GetEffectiveRole(r.Context()) + if callerRole != auth.RoleNone { + requestedRole := auth.GetRoleFromScopes(input.Scopes) + if auth.RoleLevel(requestedRole) > auth.RoleLevel(callerRole) { + as.writeErr(r.Context(), w, http.StatusForbidden, "scope elevation denied", + "requested scopes require a higher privilege level than your current role") + return + } + } + // Calculate expiration var expiresAt *time.Time if input.ExpiresInDays != nil && *input.ExpiresInDays > 0 { diff --git a/internal/api/auth_handlers_test.go b/internal/api/auth_handlers_test.go index 217d65a..a61cef8 100644 --- a/internal/api/auth_handlers_test.go +++ b/internal/api/auth_handlers_test.go @@ -408,6 +408,49 @@ func TestAudit_List_Filtering(t *testing.T) { } } +// doAuthJSONWithRole is like doAuthJSON but injects a role into the request context. +func doAuthJSONWithRole(t *testing.T, mux *stdhttp.ServeMux, method, path, body string, role auth.Role, code int) *httptest.ResponseRecorder { + t.Helper() + req := httptest.NewRequest(method, path, strings.NewReader(body)) + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + ctx := auth.ContextWithRole(req.Context(), role) + req = req.WithContext(ctx) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != code { + t.Fatalf("%s %s: expected code %d, got %d: %s", method, path, code, rr.Code, rr.Body.String()) + } + return rr +} + +func TestCreateAPIKey_ScopeElevation(t *testing.T) { + as, _, _ := setupAuthTestServer() + + // Operator (level 3) tries to create admin-level key (scope "*" = level 4) -> 403 + body := `{"name": "Admin Key", "scopes": ["*"]}` + rr := doAuthJSONWithRole(t, as.mux, stdhttp.MethodPost, "/api/v1/auth/keys", body, auth.RoleOperator, stdhttp.StatusForbidden) + if !strings.Contains(rr.Body.String(), "scope elevation denied") { + t.Errorf("Expected 'scope elevation denied' error, got: %s", rr.Body.String()) + } + + // Operator (level 3) creates operator-level key (scope "pools:write" = level 3) -> 201 + body = `{"name": "Operator Key", "scopes": ["pools:read"]}` + doAuthJSONWithRole(t, as.mux, stdhttp.MethodPost, "/api/v1/auth/keys", body, auth.RoleOperator, stdhttp.StatusCreated) + + // Viewer (level 2) tries to create operator-level key (scope "pools:write" = level 3) -> 403 + body = `{"name": "Writer Key", "scopes": ["pools:write"]}` + rr = doAuthJSONWithRole(t, as.mux, stdhttp.MethodPost, "/api/v1/auth/keys", body, auth.RoleViewer, stdhttp.StatusForbidden) + if !strings.Contains(rr.Body.String(), "scope elevation denied") { + t.Errorf("Expected 'scope elevation denied' error, got: %s", rr.Body.String()) + } + + // Admin (level 4) creates admin-level key (scope "*" = level 4) -> 201 + body = `{"name": "Admin Key", "scopes": ["*"]}` + doAuthJSONWithRole(t, as.mux, stdhttp.MethodPost, "/api/v1/auth/keys", body, auth.RoleAdmin, stdhttp.StatusCreated) +} + func TestAudit_MethodNotAllowed(t *testing.T) { as, _, _ := setupAuthTestServer() diff --git a/internal/auth/rbac.go b/internal/auth/rbac.go index 8a6779a..32c0929 100644 --- a/internal/auth/rbac.go +++ b/internal/auth/rbac.go @@ -246,6 +246,23 @@ func GetRoleFromScopes(scopes []string) Role { } } +// RoleLevel returns the numeric privilege level of a role for comparison. +// Higher values = more privileges. +func RoleLevel(r Role) int { + switch r { + case RoleAdmin: + return 4 + case RoleOperator: + return 3 + case RoleViewer: + return 2 + case RoleAuditor: + return 1 + default: + return 0 + } +} + // ValidRoles returns all valid role values. func ValidRoles() []Role { return []Role{RoleAdmin, RoleOperator, RoleViewer, RoleAuditor} From 9cbe5d023fd98793d061f179805624d5b19d0f60 Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 12:35:07 -0600 Subject: [PATCH 04/10] feat: trusted proxy configuration and login rate limiting - Add TrustedProxyConfig with CIDR-based proxy validation - clientKey() now ignores X-Forwarded-For by default (secure default) - Only trust XFF when direct peer is in CLOUDPAM_TRUSTED_PROXIES - Add per-IP login rate limiting (5 attempts/min default) - Login handler wrapped with LoginRateLimitMiddleware Co-Authored-By: Claude Opus 4.6 --- cmd/cloudpam/main.go | 23 +++- internal/api/integration_test.go | 2 +- internal/api/middleware.go | 143 ++++++++++++++++---- internal/api/middleware_test.go | 219 +++++++++++++++++++++++++++++-- internal/api/user_handlers.go | 38 +++++- 5 files changed, 383 insertions(+), 42 deletions(-) diff --git a/cmd/cloudpam/main.go b/cmd/cloudpam/main.go index 4780c73..ad1c0a2 100644 --- a/cmd/cloudpam/main.go +++ b/cmd/cloudpam/main.go @@ -112,6 +112,18 @@ func main() { ) } + // Parse trusted proxies for X-Forwarded-For handling + var proxyConfig *api.TrustedProxyConfig + if proxiesEnv := os.Getenv("CLOUDPAM_TRUSTED_PROXIES"); proxiesEnv != "" { + var err error + proxyConfig, err = api.ParseTrustedProxies(proxiesEnv) + if err != nil { + logger.Error("invalid CLOUDPAM_TRUSTED_PROXIES", "error", err) + } else { + logger.Info("trusted proxies configured", "count", len(proxyConfig.CIDRs)) + } + } + mux := http.NewServeMux() auditLogger := selectAuditLogger(logger) keyStore := selectKeyStore(logger) @@ -175,7 +187,11 @@ func main() { authSrv := api.NewAuthServerWithStores(srv, keyStore, sessionStore, userStore, auditLogger) authSrv.RegisterProtectedAuthRoutes(logger.Slog()) userSrv := api.NewUserServer(srv, keyStore, userStore, sessionStore, auditLogger) - userSrv.RegisterProtectedUserRoutes(logger.Slog()) + loginRL := api.LoginRateLimitMiddleware(api.LoginRateLimitConfig{ + AttemptsPerMinute: 5, + ProxyConfig: proxyConfig, + }) + userSrv.RegisterProtectedUserRoutes(logger.Slog(), api.WithLoginRateLimit(loginRL)) dualMW := api.DualAuthMiddleware(keyStore, sessionStore, userStore, true, logger.Slog()) discoverySrv.RegisterProtectedDiscoveryRoutes(dualMW, logger.Slog()) analysisSrv.RegisterProtectedAnalysisRoutes(dualMW, logger.Slog()) @@ -283,6 +299,11 @@ func bootstrapAdmin(logger observability.Logger, userStore auth.UserStore, usern return } + if err := auth.ValidatePassword(password, 0); err != nil { + logger.Error("bootstrap admin password does not meet requirements", "error", err) + return + } + hash, err := auth.HashPassword(password) if err != nil { logger.Error("failed to hash admin password", "error", err) diff --git a/internal/api/integration_test.go b/internal/api/integration_test.go index 16d729f..bfd3e3d 100644 --- a/internal/api/integration_test.go +++ b/internal/api/integration_test.go @@ -753,7 +753,7 @@ func TestIntegration_ErrorResponses(t *testing.T) { path: "/api/v1/pools", apiKey: "invalid-key", expectedStatus: http.StatusUnauthorized, - expectedError: "unauthorized", + expectedError: "invalid bearer token", }, { name: "not found", diff --git a/internal/api/middleware.go b/internal/api/middleware.go index fda7a3c..88085cb 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -7,6 +7,7 @@ import ( "math" "net" "net/http" + "net/netip" "os" "strconv" "strings" @@ -281,19 +282,123 @@ func RateLimitMiddleware(cfg RateLimitConfig, logger *slog.Logger) Middleware { } func clientKey(r *http.Request) string { - if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" { - if idx := strings.Index(xff, ","); idx != -1 { - xff = xff[:idx] + return clientKeyWithProxies(r, nil) +} + +// TrustedProxyConfig holds trusted proxy CIDR list for X-Forwarded-For handling. +type TrustedProxyConfig struct { + CIDRs []netip.Prefix +} + +// ParseTrustedProxies parses a comma-separated list of CIDRs. +func ParseTrustedProxies(raw string) (*TrustedProxyConfig, error) { + if raw == "" { + return &TrustedProxyConfig{}, nil + } + var cidrs []netip.Prefix + for _, s := range strings.Split(raw, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + prefix, err := netip.ParsePrefix(s) + if err != nil { + return nil, fmt.Errorf("invalid trusted proxy CIDR %q: %w", s, err) } - if ip := strings.TrimSpace(xff); ip != "" { - return ip + cidrs = append(cidrs, prefix) + } + return &TrustedProxyConfig{CIDRs: cidrs}, nil +} + +// IsTrusted checks if the remote address is from a trusted proxy. +func (tc *TrustedProxyConfig) IsTrusted(remoteAddr string) bool { + if tc == nil || len(tc.CIDRs) == 0 { + return false + } + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return false + } + addr, err := netip.ParseAddr(host) + if err != nil { + return false + } + for _, cidr := range tc.CIDRs { + if cidr.Contains(addr) { + return true + } + } + return false +} + +// clientKeyWithProxies extracts the client IP, only trusting X-Forwarded-For from trusted proxies. +func clientKeyWithProxies(r *http.Request, proxies *TrustedProxyConfig) string { + if proxies != nil && proxies.IsTrusted(r.RemoteAddr) { + if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" { + parts := strings.SplitN(xff, ",", 2) + if ip := strings.TrimSpace(parts[0]); ip != "" { + return ip + } } } host, _, err := net.SplitHostPort(r.RemoteAddr) - if err == nil && host != "" { - return host + if err != nil { + return r.RemoteAddr + } + return host +} + +// LoginRateLimitConfig configures per-IP login rate limiting. +type LoginRateLimitConfig struct { + AttemptsPerMinute int + ProxyConfig *TrustedProxyConfig +} + +// LoginRateLimitMiddleware wraps a handler with per-IP login rate limiting. +func LoginRateLimitMiddleware(cfg LoginRateLimitConfig) func(http.Handler) http.Handler { + type ipEntry struct { + limiter *rate.Limiter + lastSeen time.Time + } + var mu sync.Mutex + clients := make(map[string]*ipEntry) + + go func() { + for { + time.Sleep(5 * time.Minute) + mu.Lock() + for ip, entry := range clients { + if time.Since(entry.lastSeen) > 10*time.Minute { + delete(clients, ip) + } + } + mu.Unlock() + } + }() + + rps := rate.Limit(float64(cfg.AttemptsPerMinute) / 60.0) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := clientKeyWithProxies(r, cfg.ProxyConfig) + + mu.Lock() + entry, ok := clients[ip] + if !ok { + entry = &ipEntry{limiter: rate.NewLimiter(rps, cfg.AttemptsPerMinute)} + clients[ip] = entry + } + entry.lastSeen = time.Now() + mu.Unlock() + + if !entry.limiter.Allow() { + w.Header().Set("Retry-After", "60") + writeJSON(w, http.StatusTooManyRequests, apiError{Error: "too many login attempts", Detail: "try again later"}) + return + } + next.ServeHTTP(w, r) + }) } - return r.RemoteAddr } // AuthMiddleware validates API key authentication. @@ -414,8 +519,8 @@ func AuthMiddleware(keyStore auth.KeyStore, required bool, logger *slog.Logger) // It tries these strategies in order: // 1. "session" cookie -> session lookup // 2. Authorization: Bearer with "cpam_" prefix -> API key (existing flow) -// 3. Authorization: Bearer without prefix -> session token lookup // If required is true, unauthenticated requests get 401. +// Non-cpam_ Bearer tokens are rejected (session IDs must not appear in headers). func DualAuthMiddleware( keyStore auth.KeyStore, sessionStore auth.SessionStore, @@ -446,7 +551,7 @@ func DualAuthMiddleware( } } - // Strategy 2 & 3: Check Authorization header + // Strategy 2: Check Authorization header authHeader := r.Header.Get("Authorization") if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") { token := strings.TrimPrefix(authHeader, "Bearer ") @@ -510,22 +615,10 @@ func DualAuthMiddleware( return } - // Strategy 3: Bearer token as session token - session, err := sessionStore.Get(ctx, token) - if err == nil && session != nil && session.IsValid() { - ctx = auth.ContextWithSession(ctx, session) - ctx = auth.ContextWithRole(ctx, session.Role) - if user, _ := userStore.GetByID(ctx, session.UserID); user != nil { - ctx = auth.ContextWithUser(ctx, user) - } - r = r.WithContext(ctx) - next.ServeHTTP(w, r) - return - } - + // Unrecognized Bearer token format - reject if required { - logAuthFailure(logger, r, "invalid session token") - writeJSON(w, http.StatusUnauthorized, apiError{Error: "unauthorized", Detail: "invalid or expired session"}) + logAuthFailure(logger, r, "invalid bearer token format") + writeJSON(w, http.StatusUnauthorized, apiError{Error: "invalid bearer token", Detail: "bearer tokens must be API keys (cpam_ prefix)"}) return } } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 322a082..2cad8cb 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -237,7 +237,7 @@ func TestRateLimitMiddlewarePerIPTracking(t *testing.T) { } } -func TestRateLimitMiddlewareXForwardedFor(t *testing.T) { +func TestRateLimitMiddlewareXForwardedForIgnoredWithoutTrustedProxy(t *testing.T) { cfg := RateLimitConfig{ RequestsPerSecond: 5, Burst: 1, @@ -246,24 +246,26 @@ func TestRateLimitMiddlewareXForwardedFor(t *testing.T) { w.WriteHeader(http.StatusOK) })) - // First request with X-Forwarded-For - should succeed + // Without trusted proxies, XFF is ignored and RemoteAddr is used for rate limiting. + // Two requests from the same RemoteAddr but different XFF should still be rate limited + // because XFF is not trusted. rr1 := httptest.NewRecorder() req1 := httptest.NewRequest(http.MethodGet, "/", nil) - req1.RemoteAddr = "proxy:80" - req1.Header.Set("X-Forwarded-For", "10.0.0.1, 10.0.0.2") + req1.RemoteAddr = "192.168.1.1:80" + req1.Header.Set("X-Forwarded-For", "10.0.0.1") handler.ServeHTTP(rr1, req1) if rr1.Code != http.StatusOK { t.Fatalf("expected first request to succeed, got %d", rr1.Code) } - // Second request with same X-Forwarded-For first IP - should be rate limited + // Same RemoteAddr, different XFF — should be rate limited because XFF is ignored rr2 := httptest.NewRecorder() req2 := httptest.NewRequest(http.MethodGet, "/", nil) - req2.RemoteAddr = "proxy:80" - req2.Header.Set("X-Forwarded-For", "10.0.0.1") + req2.RemoteAddr = "192.168.1.1:80" + req2.Header.Set("X-Forwarded-For", "10.0.0.2") handler.ServeHTTP(rr2, req2) if rr2.Code != http.StatusTooManyRequests { - t.Fatalf("expected second request to be rate limited, got %d", rr2.Code) + t.Fatalf("expected second request to be rate limited (XFF ignored), got %d", rr2.Code) } } @@ -1614,3 +1616,204 @@ func TestRequirePermissionMiddleware_NilLoggerUsesDefault(t *testing.T) { t.Error("handler should have been called") } } + +// Trusted proxy tests + +func TestParseTrustedProxies_Valid(t *testing.T) { + cfg, err := ParseTrustedProxies("10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.CIDRs) != 3 { + t.Fatalf("expected 3 CIDRs, got %d", len(cfg.CIDRs)) + } +} + +func TestParseTrustedProxies_Empty(t *testing.T) { + cfg, err := ParseTrustedProxies("") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.CIDRs) != 0 { + t.Fatalf("expected 0 CIDRs, got %d", len(cfg.CIDRs)) + } +} + +func TestParseTrustedProxies_Invalid(t *testing.T) { + _, err := ParseTrustedProxies("not-a-cidr") + if err == nil { + t.Fatal("expected error for invalid CIDR") + } +} + +func TestClientKeyWithProxies_NoTrustedProxies(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "1.2.3.4:12345" + req.Header.Set("X-Forwarded-For", "10.0.0.1") + + // With nil proxies, should use RemoteAddr and ignore XFF + got := clientKeyWithProxies(req, nil) + if got != "1.2.3.4" { + t.Errorf("expected '1.2.3.4', got %q", got) + } +} + +func TestClientKeyWithProxies_TrustedProxy(t *testing.T) { + cfg, _ := ParseTrustedProxies("172.16.0.0/12") + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "172.16.0.1:8080" + req.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1") + + got := clientKeyWithProxies(req, cfg) + if got != "10.0.0.1" { + t.Errorf("expected '10.0.0.1', got %q", got) + } +} + +func TestClientKeyWithProxies_UntrustedProxy(t *testing.T) { + cfg, _ := ParseTrustedProxies("172.16.0.0/12") + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "5.6.7.8:8080" + req.Header.Set("X-Forwarded-For", "10.0.0.1") + + // 5.6.7.8 is not in 172.16.0.0/12, so XFF should be ignored + got := clientKeyWithProxies(req, cfg) + if got != "5.6.7.8" { + t.Errorf("expected '5.6.7.8', got %q", got) + } +} + +func TestClientKeyWithProxies_EmptyXFF(t *testing.T) { + cfg, _ := ParseTrustedProxies("172.16.0.0/12") + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "172.16.0.1:8080" + // No XFF header + + got := clientKeyWithProxies(req, cfg) + if got != "172.16.0.1" { + t.Errorf("expected '172.16.0.1', got %q", got) + } +} + +func TestIsTrusted_NilConfig(t *testing.T) { + var cfg *TrustedProxyConfig + if cfg.IsTrusted("1.2.3.4:80") { + t.Error("nil config should not trust anything") + } +} + +func TestIsTrusted_EmptyCIDRs(t *testing.T) { + cfg := &TrustedProxyConfig{} + if cfg.IsTrusted("1.2.3.4:80") { + t.Error("empty CIDRs should not trust anything") + } +} + +func TestIsTrusted_BadRemoteAddr(t *testing.T) { + cfg, _ := ParseTrustedProxies("10.0.0.0/8") + if cfg.IsTrusted("not-an-address") { + t.Error("bad remote addr should not be trusted") + } +} + +// Login rate limiting tests + +func TestLoginRateLimitMiddleware_AllowsBelowLimit(t *testing.T) { + loginRL := LoginRateLimitMiddleware(LoginRateLimitConfig{ + AttemptsPerMinute: 5, + }) + + handler := loginRL(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 5; i++ { + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil) + req.RemoteAddr = "1.2.3.4:12345" + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, rr.Code) + } + } +} + +func TestLoginRateLimitMiddleware_BlocksOverLimit(t *testing.T) { + loginRL := LoginRateLimitMiddleware(LoginRateLimitConfig{ + AttemptsPerMinute: 5, + }) + + handler := loginRL(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Exhaust 5 allowed attempts + for i := 0; i < 5; i++ { + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil) + req.RemoteAddr = "1.2.3.4:12345" + handler.ServeHTTP(rr, req) + } + + // 6th should be blocked + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil) + req.RemoteAddr = "1.2.3.4:12345" + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429, got %d", rr.Code) + } + + var resp apiError + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + if resp.Error != "too many login attempts" { + t.Errorf("expected 'too many login attempts', got %q", resp.Error) + } + + retryAfter := rr.Header().Get("Retry-After") + if retryAfter != "60" { + t.Errorf("expected Retry-After '60', got %q", retryAfter) + } +} + +func TestLoginRateLimitMiddleware_PerIPIsolation(t *testing.T) { + loginRL := LoginRateLimitMiddleware(LoginRateLimitConfig{ + AttemptsPerMinute: 2, + }) + + handler := loginRL(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Exhaust IP1 + for i := 0; i < 2; i++ { + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil) + req.RemoteAddr = "1.2.3.4:12345" + handler.ServeHTTP(rr, req) + } + + // IP1 should be blocked + rr1 := httptest.NewRecorder() + req1 := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil) + req1.RemoteAddr = "1.2.3.4:12345" + handler.ServeHTTP(rr1, req1) + if rr1.Code != http.StatusTooManyRequests { + t.Fatalf("expected IP1 to be blocked, got %d", rr1.Code) + } + + // IP2 should still work + rr2 := httptest.NewRecorder() + req2 := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil) + req2.RemoteAddr = "5.6.7.8:54321" + handler.ServeHTTP(rr2, req2) + if rr2.Code != http.StatusOK { + t.Fatalf("expected IP2 to succeed, got %d", rr2.Code) + } +} diff --git a/internal/api/user_handlers.go b/internal/api/user_handlers.go index 776590c..09e2656 100644 --- a/internal/api/user_handlers.go +++ b/internal/api/user_handlers.go @@ -34,6 +34,21 @@ func NewUserServer(s *Server, keyStore auth.KeyStore, userStore auth.UserStore, } } +// userRouteConfig holds optional configuration for user route registration. +type userRouteConfig struct { + loginRateLimit func(http.Handler) http.Handler +} + +// UserRouteOption configures user route registration. +type UserRouteOption func(*userRouteConfig) + +// WithLoginRateLimit sets login rate limiting middleware. +func WithLoginRateLimit(mw func(http.Handler) http.Handler) UserRouteOption { + return func(cfg *userRouteConfig) { + cfg.loginRateLimit = mw + } +} + // RegisterUserRoutes registers user auth routes without RBAC (development mode). func (us *UserServer) RegisterUserRoutes() { us.mux.HandleFunc("/api/v1/auth/login", us.handleLogin) @@ -44,13 +59,22 @@ func (us *UserServer) RegisterUserRoutes() { } // RegisterProtectedUserRoutes registers user auth routes with RBAC. -func (us *UserServer) RegisterProtectedUserRoutes(logger *slog.Logger) { +func (us *UserServer) RegisterProtectedUserRoutes(logger *slog.Logger, opts ...UserRouteOption) { if logger == nil { logger = slog.Default() } - // Login is always public (no auth required). - us.mux.HandleFunc("/api/v1/auth/login", us.handleLogin) + var cfg userRouteConfig + for _, o := range opts { + o(&cfg) + } + + // Login is always public (no auth required), but rate limited. + loginHandler := http.Handler(http.HandlerFunc(us.handleLogin)) + if cfg.loginRateLimit != nil { + loginHandler = cfg.loginRateLimit(loginHandler) + } + us.mux.Handle("/api/v1/auth/login", loginHandler) // Dual auth middleware (session or API key). dualMW := DualAuthMiddleware(us.keyStore, us.sessionStore, us.userStore, true, logger) @@ -391,8 +415,8 @@ func (us *UserServer) createUser(w http.ResponseWriter, r *http.Request) { us.writeErr(ctx, w, http.StatusBadRequest, "password is required", "") return } - if len(input.Password) < 8 { - us.writeErr(ctx, w, http.StatusBadRequest, "password too short", "minimum 8 characters") + if err := auth.ValidatePassword(input.Password, 0); err != nil { + us.writeErr(ctx, w, http.StatusBadRequest, "password too weak", err.Error()) return } @@ -583,8 +607,8 @@ func (us *UserServer) changePassword(w http.ResponseWriter, r *http.Request, id us.writeErr(ctx, w, http.StatusBadRequest, "new_password is required", "") return } - if len(input.NewPassword) < 8 { - us.writeErr(ctx, w, http.StatusBadRequest, "password too short", "minimum 8 characters") + if err := auth.ValidatePassword(input.NewPassword, 0); err != nil { + us.writeErr(ctx, w, http.StatusBadRequest, "password too weak", err.Error()) return } From 0466159b6267814caece83424a2e49f9a73987cc Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 12:35:17 -0600 Subject: [PATCH 05/10] security: remove bearer-as-session-token, harden password policy - Remove Strategy 3 (Bearer token as session ID) from DualAuthMiddleware - Sessions use cookies only; API keys use Bearer tokens (clean separation) - Password minimum increased to 12 chars (NIST 800-63B) - Password maximum enforced at 72 chars (bcrypt truncation boundary) - ValidatePassword() used consistently in setup, user creation, and password change Co-Authored-By: Claude Opus 4.6 --- internal/api/system_handlers.go | 4 +- internal/auth/password.go | 30 ++++++++++++- internal/auth/password_test.go | 79 +++++++++++++++++++++++++++++++++ ui/src/pages/SetupPage.tsx | 6 +-- 4 files changed, 112 insertions(+), 7 deletions(-) create mode 100644 internal/auth/password_test.go diff --git a/internal/api/system_handlers.go b/internal/api/system_handlers.go index 90fbf31..64d9745 100644 --- a/internal/api/system_handlers.go +++ b/internal/api/system_handlers.go @@ -156,8 +156,8 @@ func (s *Server) handleSetup(w http.ResponseWriter, r *http.Request) { s.writeErr(r.Context(), w, http.StatusBadRequest, "username is required", "") return } - if len(req.Password) < 8 { - s.writeErr(r.Context(), w, http.StatusBadRequest, "password must be at least 8 characters", "") + if err := auth.ValidatePassword(req.Password, 0); err != nil { + s.writeErr(r.Context(), w, http.StatusBadRequest, "password too weak", err.Error()) return } diff --git a/internal/auth/password.go b/internal/auth/password.go index 045d852..918a0e3 100644 --- a/internal/auth/password.go +++ b/internal/auth/password.go @@ -1,8 +1,34 @@ package auth -import "golang.org/x/crypto/bcrypt" +import ( + "fmt" -const bcryptCost = 12 + "golang.org/x/crypto/bcrypt" +) + +const ( + bcryptCost = 12 + + // DefaultMinPasswordLength is the minimum password length enforced by default. + DefaultMinPasswordLength = 12 + + // MaxPasswordLength is the maximum password length (bcrypt truncation boundary). + MaxPasswordLength = 72 +) + +// ValidatePassword checks password meets policy requirements. +func ValidatePassword(password string, minLength int) error { + if minLength <= 0 { + minLength = DefaultMinPasswordLength + } + if len(password) < minLength { + return fmt.Errorf("password must be at least %d characters", minLength) + } + if len(password) > MaxPasswordLength { + return fmt.Errorf("password must be at most %d characters", MaxPasswordLength) + } + return nil +} // HashPassword hashes a plaintext password using bcrypt. func HashPassword(password string) ([]byte, error) { diff --git a/internal/auth/password_test.go b/internal/auth/password_test.go new file mode 100644 index 0000000..86adbbd --- /dev/null +++ b/internal/auth/password_test.go @@ -0,0 +1,79 @@ +package auth + +import ( + "strings" + "testing" +) + +func TestValidatePassword(t *testing.T) { + tests := []struct { + name string + password string + minLength int + wantErr bool + errMsg string + }{ + { + name: "too short (11 chars)", + password: "abcdefghijk", + minLength: 0, + wantErr: true, + errMsg: "at least 12", + }, + { + name: "exactly 12 chars", + password: "abcdefghijkl", + minLength: 0, + wantErr: false, + }, + { + name: "at max (72 chars)", + password: strings.Repeat("a", 72), + minLength: 0, + wantErr: false, + }, + { + name: "over max (73 chars)", + password: strings.Repeat("a", 73), + minLength: 0, + wantErr: true, + errMsg: "at most 72", + }, + { + name: "empty password", + password: "", + minLength: 0, + wantErr: true, + errMsg: "at least 12", + }, + { + name: "custom min length (8) with 8 chars", + password: "abcdefgh", + minLength: 8, + wantErr: false, + }, + { + name: "custom min length (8) with 7 chars", + password: "abcdefg", + minLength: 8, + wantErr: true, + errMsg: "at least 8", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidatePassword(tt.password, tt.minLength) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("expected error to contain %q, got %q", tt.errMsg, err.Error()) + } + } else if err != nil { + t.Errorf("expected nil error, got %v", err) + } + }) + } +} diff --git a/ui/src/pages/SetupPage.tsx b/ui/src/pages/SetupPage.tsx index ca4c397..ef8ae4d 100644 --- a/ui/src/pages/SetupPage.tsx +++ b/ui/src/pages/SetupPage.tsx @@ -31,8 +31,8 @@ export default function SetupPage() { return } - if (password.length < 8) { - setError('Password must be at least 8 characters') + if (password.length < 12) { + setError('Password must be at least 12 characters') return } @@ -125,7 +125,7 @@ export default function SetupPage() { type={showPassword ? 'text' : 'password'} value={password} onChange={e => setPassword(e.target.value)} - placeholder="Minimum 8 characters" + placeholder="Minimum 12 characters" className="w-full pl-3 pr-10 py-2 border rounded-lg text-sm dark:bg-gray-700 dark:border-gray-600 dark:text-white focus:ring-2 focus:ring-blue-500 focus:border-blue-500" autoComplete="new-password" /> From c580135fd3ff1697edafd0bfaa235df38d1a7af5 Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 12:42:02 -0600 Subject: [PATCH 06/10] feat: security settings table, store interface, and API endpoint - SecuritySettings domain type with safe defaults - SettingsStore interface with memory and SQLite implementations - Migration 0016: settings table (key-value with JSON) - GET/PATCH /api/v1/settings/security with RBAC (admin only) - Full input validation on PATCH with bounds checking - ResourceSettings added to RBAC permission model Co-Authored-By: Claude Opus 4.6 --- cmd/cloudpam/main.go | 7 ++ internal/api/settings_handlers.go | 90 ++++++++++++++ internal/api/settings_handlers_test.go | 162 +++++++++++++++++++++++++ internal/auth/rbac.go | 4 + internal/domain/settings.go | 25 ++++ internal/storage/settings.go | 13 ++ internal/storage/settings_memory.go | 34 ++++++ internal/storage/sqlite/settings.go | 42 +++++++ migrations/0016_settings.sql | 6 + 9 files changed, 383 insertions(+) create mode 100644 internal/api/settings_handlers.go create mode 100644 internal/api/settings_handlers_test.go create mode 100644 internal/domain/settings.go create mode 100644 internal/storage/settings.go create mode 100644 internal/storage/settings_memory.go create mode 100644 internal/storage/sqlite/settings.go create mode 100644 migrations/0016_settings.sql diff --git a/cmd/cloudpam/main.go b/cmd/cloudpam/main.go index ad1c0a2..13f524d 100644 --- a/cmd/cloudpam/main.go +++ b/cmd/cloudpam/main.go @@ -18,6 +18,7 @@ import ( awscollector "cloudpam/internal/discovery/aws" "cloudpam/internal/api" "cloudpam/internal/observability" + "cloudpam/internal/storage" "cloudpam/internal/planning" "cloudpam/internal/planning/llm" @@ -182,6 +183,11 @@ func main() { aiSrv := api.NewAIPlanningServer(srv, aiService, convStore) logger.Info("ai planning subsystem initialized") + // Initialize settings subsystem + settingsStore := storage.NewMemorySettingsStore() + settingsSrv := api.NewSettingsServer(srv, settingsStore) + logger.Info("settings subsystem initialized") + // Auth is always enabled — register protected routes with RBAC. srv.RegisterProtectedRoutes(keyStore, sessionStore, userStore, logger.Slog()) authSrv := api.NewAuthServerWithStores(srv, keyStore, sessionStore, userStore, auditLogger) @@ -197,6 +203,7 @@ func main() { analysisSrv.RegisterProtectedAnalysisRoutes(dualMW, logger.Slog()) recSrv.RegisterProtectedRecommendationRoutes(dualMW, logger.Slog()) aiSrv.RegisterProtectedAIPlanningRoutes(dualMW, logger.Slog()) + settingsSrv.RegisterProtectedSettingsRoutes(dualMW, logger.Slog()) if len(existingUsers) == 0 { logger.Info("first-boot setup required", "hint", "visit the UI to create an admin account") diff --git a/internal/api/settings_handlers.go b/internal/api/settings_handlers.go new file mode 100644 index 0000000..f1bc1ff --- /dev/null +++ b/internal/api/settings_handlers.go @@ -0,0 +1,90 @@ +package api + +import ( + "encoding/json" + "log/slog" + "net/http" + + "cloudpam/internal/auth" + "cloudpam/internal/domain" + "cloudpam/internal/storage" +) + +// SettingsServer handles settings API endpoints. +type SettingsServer struct { + *Server + settingsStore storage.SettingsStore +} + +// NewSettingsServer creates a new SettingsServer. +func NewSettingsServer(srv *Server, store storage.SettingsStore) *SettingsServer { + return &SettingsServer{Server: srv, settingsStore: store} +} + +// RegisterProtectedSettingsRoutes registers settings endpoints with RBAC. +func (ss *SettingsServer) RegisterProtectedSettingsRoutes(dualMW func(http.Handler) http.Handler, slogger *slog.Logger) { + adminRead := RequirePermissionMiddleware(auth.ResourceSettings, auth.ActionRead, slogger) + adminWrite := RequirePermissionMiddleware(auth.ResourceSettings, auth.ActionWrite, slogger) + + ss.mux.Handle("GET /api/v1/settings/security", + dualMW(adminRead(http.HandlerFunc(ss.handleGetSecuritySettings)))) + ss.mux.Handle("PATCH /api/v1/settings/security", + dualMW(adminWrite(http.HandlerFunc(ss.handleUpdateSecuritySettings)))) +} + +// RegisterSettingsRoutes registers settings endpoints without RBAC (for tests). +func (ss *SettingsServer) RegisterSettingsRoutes() { + ss.mux.HandleFunc("GET /api/v1/settings/security", ss.handleGetSecuritySettings) + ss.mux.HandleFunc("PATCH /api/v1/settings/security", ss.handleUpdateSecuritySettings) +} + +func (ss *SettingsServer) handleGetSecuritySettings(w http.ResponseWriter, r *http.Request) { + settings, err := ss.settingsStore.GetSecuritySettings(r.Context()) + if err != nil { + ss.writeErr(r.Context(), w, http.StatusInternalServerError, "failed to load settings", err.Error()) + return + } + writeJSON(w, http.StatusOK, settings) +} + +func (ss *SettingsServer) handleUpdateSecuritySettings(w http.ResponseWriter, r *http.Request) { + var input domain.SecuritySettings + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + ss.writeErr(r.Context(), w, http.StatusBadRequest, "invalid request body", err.Error()) + return + } + + // Validate bounds + if input.SessionDurationHours < 1 || input.SessionDurationHours > 720 { + ss.writeErr(r.Context(), w, http.StatusBadRequest, "invalid session_duration_hours", "must be between 1 and 720") + return + } + if input.MaxSessionsPerUser < 1 || input.MaxSessionsPerUser > 100 { + ss.writeErr(r.Context(), w, http.StatusBadRequest, "invalid max_sessions_per_user", "must be between 1 and 100") + return + } + if input.PasswordMinLength < 8 || input.PasswordMinLength > 72 { + ss.writeErr(r.Context(), w, http.StatusBadRequest, "invalid password_min_length", "must be between 8 and 72") + return + } + if input.PasswordMaxLength < input.PasswordMinLength || input.PasswordMaxLength > 72 { + ss.writeErr(r.Context(), w, http.StatusBadRequest, "invalid password_max_length", "must be between min_length and 72") + return + } + if input.LoginRateLimitPerMin < 1 || input.LoginRateLimitPerMin > 60 { + ss.writeErr(r.Context(), w, http.StatusBadRequest, "invalid login_rate_limit_per_minute", "must be between 1 and 60") + return + } + if input.AccountLockoutAttempts < 0 || input.AccountLockoutAttempts > 100 { + ss.writeErr(r.Context(), w, http.StatusBadRequest, "invalid account_lockout_attempts", "must be between 0 and 100") + return + } + + if err := ss.settingsStore.UpdateSecuritySettings(r.Context(), &input); err != nil { + ss.writeErr(r.Context(), w, http.StatusInternalServerError, "failed to save settings", err.Error()) + return + } + + ss.logAudit(r.Context(), "update", "settings", "security", "security_settings", http.StatusOK) + writeJSON(w, http.StatusOK, input) +} diff --git a/internal/api/settings_handlers_test.go b/internal/api/settings_handlers_test.go new file mode 100644 index 0000000..368098b --- /dev/null +++ b/internal/api/settings_handlers_test.go @@ -0,0 +1,162 @@ +package api + +import ( + "encoding/json" + stdhttp "net/http" + "net/http/httptest" + "strings" + "testing" + + "cloudpam/internal/domain" + "cloudpam/internal/storage" +) + +func setupSettingsServer() *stdhttp.ServeMux { + st := storage.NewMemoryStore() + mux := stdhttp.NewServeMux() + srv := NewServerWithSlog(mux, st, nil) + + settingsStore := storage.NewMemorySettingsStore() + settingsSrv := NewSettingsServer(srv, settingsStore) + settingsSrv.RegisterSettingsRoutes() + + return mux +} + +func TestSettingsHandler_GetDefaults(t *testing.T) { + mux := setupSettingsServer() + + req := httptest.NewRequest(stdhttp.MethodGet, "/api/v1/settings/security", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != stdhttp.StatusOK { + t.Fatalf("expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + var settings domain.SecuritySettings + if err := json.NewDecoder(rr.Body).Decode(&settings); err != nil { + t.Fatalf("decode error: %v", err) + } + + defaults := domain.DefaultSecuritySettings() + if settings.SessionDurationHours != defaults.SessionDurationHours { + t.Errorf("session_duration_hours: got %d, want %d", settings.SessionDurationHours, defaults.SessionDurationHours) + } + if settings.MaxSessionsPerUser != defaults.MaxSessionsPerUser { + t.Errorf("max_sessions_per_user: got %d, want %d", settings.MaxSessionsPerUser, defaults.MaxSessionsPerUser) + } + if settings.PasswordMinLength != defaults.PasswordMinLength { + t.Errorf("password_min_length: got %d, want %d", settings.PasswordMinLength, defaults.PasswordMinLength) + } + if settings.PasswordMaxLength != defaults.PasswordMaxLength { + t.Errorf("password_max_length: got %d, want %d", settings.PasswordMaxLength, defaults.PasswordMaxLength) + } + if settings.LoginRateLimitPerMin != defaults.LoginRateLimitPerMin { + t.Errorf("login_rate_limit_per_minute: got %d, want %d", settings.LoginRateLimitPerMin, defaults.LoginRateLimitPerMin) + } + if settings.AccountLockoutAttempts != defaults.AccountLockoutAttempts { + t.Errorf("account_lockout_attempts: got %d, want %d", settings.AccountLockoutAttempts, defaults.AccountLockoutAttempts) + } +} + +func TestSettingsHandler_UpdateValid(t *testing.T) { + mux := setupSettingsServer() + + body := `{ + "session_duration_hours": 48, + "max_sessions_per_user": 5, + "password_min_length": 16, + "password_max_length": 64, + "login_rate_limit_per_minute": 10, + "account_lockout_attempts": 5, + "trusted_proxies": ["10.0.0.0/8"] + }` + req := httptest.NewRequest(stdhttp.MethodPatch, "/api/v1/settings/security", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != stdhttp.StatusOK { + t.Fatalf("expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + var settings domain.SecuritySettings + if err := json.NewDecoder(rr.Body).Decode(&settings); err != nil { + t.Fatalf("decode error: %v", err) + } + + if settings.SessionDurationHours != 48 { + t.Errorf("session_duration_hours: got %d, want 48", settings.SessionDurationHours) + } + if settings.MaxSessionsPerUser != 5 { + t.Errorf("max_sessions_per_user: got %d, want 5", settings.MaxSessionsPerUser) + } + if settings.PasswordMinLength != 16 { + t.Errorf("password_min_length: got %d, want 16", settings.PasswordMinLength) + } + + // Verify GET returns updated values + getReq := httptest.NewRequest(stdhttp.MethodGet, "/api/v1/settings/security", nil) + getRR := httptest.NewRecorder() + mux.ServeHTTP(getRR, getReq) + + var updated domain.SecuritySettings + if err := json.NewDecoder(getRR.Body).Decode(&updated); err != nil { + t.Fatalf("decode error: %v", err) + } + if updated.SessionDurationHours != 48 { + t.Errorf("after update, session_duration_hours: got %d, want 48", updated.SessionDurationHours) + } +} + +func TestSettingsHandler_UpdateInvalidBounds(t *testing.T) { + mux := setupSettingsServer() + + tests := []struct { + name string + body string + }{ + { + name: "session_duration_hours too low", + body: `{"session_duration_hours":0,"max_sessions_per_user":10,"password_min_length":12,"password_max_length":72,"login_rate_limit_per_minute":5,"account_lockout_attempts":0}`, + }, + { + name: "session_duration_hours too high", + body: `{"session_duration_hours":721,"max_sessions_per_user":10,"password_min_length":12,"password_max_length":72,"login_rate_limit_per_minute":5,"account_lockout_attempts":0}`, + }, + { + name: "max_sessions_per_user too low", + body: `{"session_duration_hours":24,"max_sessions_per_user":0,"password_min_length":12,"password_max_length":72,"login_rate_limit_per_minute":5,"account_lockout_attempts":0}`, + }, + { + name: "password_min_length too low", + body: `{"session_duration_hours":24,"max_sessions_per_user":10,"password_min_length":7,"password_max_length":72,"login_rate_limit_per_minute":5,"account_lockout_attempts":0}`, + }, + { + name: "password_max_length less than min", + body: `{"session_duration_hours":24,"max_sessions_per_user":10,"password_min_length":16,"password_max_length":10,"login_rate_limit_per_minute":5,"account_lockout_attempts":0}`, + }, + { + name: "login_rate_limit_per_minute too low", + body: `{"session_duration_hours":24,"max_sessions_per_user":10,"password_min_length":12,"password_max_length":72,"login_rate_limit_per_minute":0,"account_lockout_attempts":0}`, + }, + { + name: "account_lockout_attempts too high", + body: `{"session_duration_hours":24,"max_sessions_per_user":10,"password_min_length":12,"password_max_length":72,"login_rate_limit_per_minute":5,"account_lockout_attempts":101}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(stdhttp.MethodPatch, "/api/v1/settings/security", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != stdhttp.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", rr.Code, rr.Body.String()) + } + }) + } +} diff --git a/internal/auth/rbac.go b/internal/auth/rbac.go index 32c0929..ade0de9 100644 --- a/internal/auth/rbac.go +++ b/internal/auth/rbac.go @@ -33,6 +33,7 @@ const ( ResourceAudit = "audit" ResourceUsers = "users" ResourceDiscovery = "discovery" + ResourceSettings = "settings" ) // Action constants for permission checks. @@ -42,6 +43,7 @@ const ( ActionUpdate = "update" ActionDelete = "delete" ActionList = "list" + ActionWrite = "write" ) // Permission represents an action on a resource. @@ -87,6 +89,8 @@ var RolePermissions = map[Role][]Permission{ {ResourceDiscovery, ActionUpdate}, {ResourceDiscovery, ActionDelete}, {ResourceDiscovery, ActionList}, + {ResourceSettings, ActionRead}, + {ResourceSettings, ActionWrite}, }, RoleOperator: { // Read/write access to pools, accounts, and discovery diff --git a/internal/domain/settings.go b/internal/domain/settings.go new file mode 100644 index 0000000..d430f40 --- /dev/null +++ b/internal/domain/settings.go @@ -0,0 +1,25 @@ +package domain + +// SecuritySettings holds runtime security configuration. +type SecuritySettings struct { + SessionDurationHours int `json:"session_duration_hours"` + MaxSessionsPerUser int `json:"max_sessions_per_user"` + PasswordMinLength int `json:"password_min_length"` + PasswordMaxLength int `json:"password_max_length"` + LoginRateLimitPerMin int `json:"login_rate_limit_per_minute"` + AccountLockoutAttempts int `json:"account_lockout_attempts"` + TrustedProxies []string `json:"trusted_proxies"` +} + +// DefaultSecuritySettings returns safe defaults. +func DefaultSecuritySettings() SecuritySettings { + return SecuritySettings{ + SessionDurationHours: 24, + MaxSessionsPerUser: 10, + PasswordMinLength: 12, + PasswordMaxLength: 72, + LoginRateLimitPerMin: 5, + AccountLockoutAttempts: 0, + TrustedProxies: []string{}, + } +} diff --git a/internal/storage/settings.go b/internal/storage/settings.go new file mode 100644 index 0000000..0ae5640 --- /dev/null +++ b/internal/storage/settings.go @@ -0,0 +1,13 @@ +package storage + +import ( + "context" + + "cloudpam/internal/domain" +) + +// SettingsStore manages application settings. +type SettingsStore interface { + GetSecuritySettings(ctx context.Context) (*domain.SecuritySettings, error) + UpdateSecuritySettings(ctx context.Context, settings *domain.SecuritySettings) error +} diff --git a/internal/storage/settings_memory.go b/internal/storage/settings_memory.go new file mode 100644 index 0000000..fb46a3e --- /dev/null +++ b/internal/storage/settings_memory.go @@ -0,0 +1,34 @@ +package storage + +import ( + "context" + "sync" + + "cloudpam/internal/domain" +) + +// MemorySettingsStore is an in-memory implementation of SettingsStore. +type MemorySettingsStore struct { + mu sync.RWMutex + security *domain.SecuritySettings +} + +// NewMemorySettingsStore creates a new in-memory settings store with defaults. +func NewMemorySettingsStore() *MemorySettingsStore { + defaults := domain.DefaultSecuritySettings() + return &MemorySettingsStore{security: &defaults} +} + +func (s *MemorySettingsStore) GetSecuritySettings(_ context.Context) (*domain.SecuritySettings, error) { + s.mu.RLock() + defer s.mu.RUnlock() + copy := *s.security + return ©, nil +} + +func (s *MemorySettingsStore) UpdateSecuritySettings(_ context.Context, settings *domain.SecuritySettings) error { + s.mu.Lock() + defer s.mu.Unlock() + s.security = settings + return nil +} diff --git a/internal/storage/sqlite/settings.go b/internal/storage/sqlite/settings.go new file mode 100644 index 0000000..7e572c9 --- /dev/null +++ b/internal/storage/sqlite/settings.go @@ -0,0 +1,42 @@ +//go:build sqlite + +package sqlite + +import ( + "context" + "database/sql" + "encoding/json" + + "cloudpam/internal/domain" +) + +// GetSecuritySettings retrieves security settings from the database. +func (s *Store) GetSecuritySettings(ctx context.Context) (*domain.SecuritySettings, error) { + var raw string + err := s.db.QueryRowContext(ctx, `SELECT value FROM settings WHERE key = 'security'`).Scan(&raw) + if err == sql.ErrNoRows { + defaults := domain.DefaultSecuritySettings() + return &defaults, nil + } + if err != nil { + return nil, err + } + var settings domain.SecuritySettings + if err := json.Unmarshal([]byte(raw), &settings); err != nil { + return nil, err + } + return &settings, nil +} + +// UpdateSecuritySettings saves security settings to the database. +func (s *Store) UpdateSecuritySettings(ctx context.Context, settings *domain.SecuritySettings) error { + raw, err := json.Marshal(settings) + if err != nil { + return err + } + _, err = s.db.ExecContext(ctx, + `INSERT INTO settings (key, value, updated_at) VALUES ('security', ?, datetime('now')) + ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at`, + string(raw)) + return err +} diff --git a/migrations/0016_settings.sql b/migrations/0016_settings.sql new file mode 100644 index 0000000..17eb760 --- /dev/null +++ b/migrations/0016_settings.sql @@ -0,0 +1,6 @@ +-- Application settings (key-value with JSON values) +CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); From 2c1fe472bef2b9bd55355e5b88b05e58667645d1 Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 12:42:14 -0600 Subject: [PATCH 07/10] feat: session hardening and revoke-all-sessions API - Add ListByUserID to SessionStore interface (memory, SQLite, PostgreSQL) - Enforce max 10 concurrent sessions per user (evict oldest on overflow) - POST /api/v1/auth/users/{id}/revoke-sessions endpoint - Admin or self-service session revocation with audit logging Co-Authored-By: Claude Opus 4.6 --- internal/api/user_handlers.go | 60 ++++++ internal/api/user_handlers_test.go | 285 +++++++++++++++++++++++++++++ internal/auth/session.go | 16 ++ internal/auth/session_postgres.go | 21 +++ internal/auth/session_sqlite.go | 23 +++ 5 files changed, 405 insertions(+) create mode 100644 internal/api/user_handlers_test.go diff --git a/internal/api/user_handlers.go b/internal/api/user_handlers.go index 09e2656..09bffc0 100644 --- a/internal/api/user_handlers.go +++ b/internal/api/user_handlers.go @@ -5,6 +5,7 @@ import ( "encoding/json" "log/slog" "net/http" + "sort" "strings" "time" @@ -123,6 +124,17 @@ func (us *UserServer) RegisterProtectedUserRoutes(logger *slog.Logger, opts ...U return } + // Check for /revoke-sessions sub-route. + if len(parts) == 2 && strings.TrimSuffix(parts[1], "/") == "revoke-sessions" { + if r.Method == http.MethodPost { + us.handleRevokeSessions(w, r, id) + return + } + w.Header().Set("Allow", http.MethodPost) + us.writeErr(r.Context(), w, http.StatusMethodNotAllowed, "method not allowed", "") + return + } + switch r.Method { case http.MethodGet: usersReadMW(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -206,6 +218,21 @@ func (us *UserServer) handleLogin(w http.ResponseWriter, r *http.Request) { return } + // Enforce max concurrent sessions per user. + if sessions, err := us.sessionStore.ListByUserID(ctx, user.ID); err == nil { + const maxSessions = 10 + if len(sessions) > maxSessions { + // Sort by creation time (oldest first) and evict excess. + sort.Slice(sessions, func(i, j int) bool { + return sessions[i].CreatedAt.Before(sessions[j].CreatedAt) + }) + excess := len(sessions) - maxSessions + for i := 0; i < excess; i++ { + _ = us.sessionStore.Delete(ctx, sessions[i].ID) + } + } + } + // Update last login. now := time.Now().UTC() _ = us.userStore.UpdateLastLogin(ctx, user.ID, now) @@ -352,6 +379,16 @@ func (us *UserServer) handleUserByID(w http.ResponseWriter, r *http.Request) { return } + if len(parts) == 2 && strings.TrimSuffix(parts[1], "/") == "revoke-sessions" { + if r.Method == http.MethodPost { + us.handleRevokeSessions(w, r, id) + return + } + w.Header().Set("Allow", http.MethodPost) + us.writeErr(r.Context(), w, http.StatusMethodNotAllowed, "method not allowed", "") + return + } + switch r.Method { case http.MethodGet: us.getUser(w, r, id) @@ -643,6 +680,29 @@ func (us *UserServer) changePassword(w http.ResponseWriter, r *http.Request, id w.WriteHeader(http.StatusNoContent) } +// handleRevokeSessions revokes all sessions for a given user. +// POST /api/v1/auth/users/{id}/revoke-sessions +// Accessible by admins or the user themselves. +func (us *UserServer) handleRevokeSessions(w http.ResponseWriter, r *http.Request, id string) { + ctx := r.Context() + + // Check authorization: admin or self. + callerRole := auth.GetEffectiveRole(ctx) + caller := auth.UserFromContext(ctx) + if callerRole != auth.RoleAdmin && (caller == nil || caller.ID != id) { + us.writeErr(ctx, w, http.StatusForbidden, "forbidden", "only admins or the user themselves can revoke sessions") + return + } + + if err := us.sessionStore.DeleteByUserID(ctx, id); err != nil { + us.writeErr(ctx, w, http.StatusInternalServerError, "failed to revoke sessions", err.Error()) + return + } + + us.logAuditEvent(ctx, "revoke_sessions", audit.ResourceUser, id, "", http.StatusOK) + writeJSON(w, http.StatusOK, map[string]string{"status": "sessions revoked"}) +} + // logAuditEvent is a helper to log audit events with actor context. func (us *UserServer) logAuditEvent(ctx context.Context, action, resourceType, resourceID, resourceName string, statusCode int) { if us.auditLogger == nil { diff --git a/internal/api/user_handlers_test.go b/internal/api/user_handlers_test.go new file mode 100644 index 0000000..cae7e19 --- /dev/null +++ b/internal/api/user_handlers_test.go @@ -0,0 +1,285 @@ +package api + +import ( + "context" + "encoding/json" + "io" + stdhttp "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "cloudpam/internal/audit" + "cloudpam/internal/auth" + "cloudpam/internal/observability" + "cloudpam/internal/storage" +) + +func setupUserTestServer() (*UserServer, auth.SessionStore, auth.UserStore) { + st := storage.NewMemoryStore() + mux := stdhttp.NewServeMux() + logger := observability.NewLogger(observability.Config{ + Level: "info", + Format: "json", + Output: io.Discard, + }) + auditLogger := audit.NewMemoryAuditLogger() + srv := NewServer(mux, st, logger, nil, auditLogger) + + keyStore := auth.NewMemoryKeyStore() + sessionStore := auth.NewMemorySessionStore() + userStore := auth.NewMemoryUserStore() + + us := NewUserServer(srv, keyStore, userStore, sessionStore, auditLogger) + srv.registerUnprotectedTestRoutes() + us.RegisterUserRoutes() + + return us, sessionStore, userStore +} + +func TestRevokeSessions_Success(t *testing.T) { + us, sessionStore, userStore := setupUserTestServer() + + // Create a test user. + ctx := context.Background() + hash, err := auth.HashPassword("TestPass123!") + if err != nil { + t.Fatalf("hash password: %v", err) + } + user := &auth.User{ + ID: "user-1", + Username: "testuser", + Email: "test@example.com", + Role: auth.RoleViewer, + PasswordHash: hash, + IsActive: true, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + if err := userStore.Create(ctx, user); err != nil { + t.Fatalf("create user: %v", err) + } + + // Create 3 sessions for the user. + for i := 0; i < 3; i++ { + sess, err := auth.NewSession(user.ID, user.Role, auth.DefaultSessionDuration, nil) + if err != nil { + t.Fatalf("new session: %v", err) + } + if err := sessionStore.Create(ctx, sess); err != nil { + t.Fatalf("create session: %v", err) + } + } + + // Verify 3 sessions exist. + sessions, err := sessionStore.ListByUserID(ctx, user.ID) + if err != nil { + t.Fatalf("list sessions: %v", err) + } + if len(sessions) != 3 { + t.Fatalf("expected 3 sessions, got %d", len(sessions)) + } + + // Call revoke-sessions endpoint (as admin context). + req := httptest.NewRequest(stdhttp.MethodPost, "/api/v1/auth/users/"+user.ID+"/revoke-sessions", nil) + req = req.WithContext(auth.ContextWithRole(auth.ContextWithUser(req.Context(), user), auth.RoleAdmin)) + rr := httptest.NewRecorder() + us.mux.ServeHTTP(rr, req) + + if rr.Code != stdhttp.StatusOK { + t.Fatalf("expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + // Verify response. + var resp map[string]string + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if resp["status"] != "sessions revoked" { + t.Errorf("expected status 'sessions revoked', got %q", resp["status"]) + } + + // Verify all sessions are deleted. + sessions, err = sessionStore.ListByUserID(ctx, user.ID) + if err != nil { + t.Fatalf("list sessions after revoke: %v", err) + } + if len(sessions) != 0 { + t.Errorf("expected 0 sessions after revoke, got %d", len(sessions)) + } +} + +func TestRevokeSessions_SelfService(t *testing.T) { + us, sessionStore, userStore := setupUserTestServer() + + ctx := context.Background() + hash, _ := auth.HashPassword("TestPass123!") + user := &auth.User{ + ID: "user-2", + Username: "selfuser", + Email: "self@example.com", + Role: auth.RoleViewer, + PasswordHash: hash, + IsActive: true, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + if err := userStore.Create(ctx, user); err != nil { + t.Fatalf("create user: %v", err) + } + + // Create a session. + sess, _ := auth.NewSession(user.ID, user.Role, auth.DefaultSessionDuration, nil) + _ = sessionStore.Create(ctx, sess) + + // User revokes own sessions (non-admin). + req := httptest.NewRequest(stdhttp.MethodPost, "/api/v1/auth/users/"+user.ID+"/revoke-sessions", nil) + req = req.WithContext(auth.ContextWithRole(auth.ContextWithUser(req.Context(), user), auth.RoleViewer)) + rr := httptest.NewRecorder() + us.mux.ServeHTTP(rr, req) + + if rr.Code != stdhttp.StatusOK { + t.Fatalf("expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + sessions, _ := sessionStore.ListByUserID(ctx, user.ID) + if len(sessions) != 0 { + t.Errorf("expected 0 sessions, got %d", len(sessions)) + } +} + +func TestRevokeSessions_Forbidden(t *testing.T) { + us, sessionStore, userStore := setupUserTestServer() + + ctx := context.Background() + hash, _ := auth.HashPassword("TestPass123!") + + targetUser := &auth.User{ + ID: "user-target", + Username: "target", + Role: auth.RoleViewer, + PasswordHash: hash, + IsActive: true, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + callerUser := &auth.User{ + ID: "user-caller", + Username: "caller", + Role: auth.RoleViewer, + PasswordHash: hash, + IsActive: true, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + _ = userStore.Create(ctx, targetUser) + _ = userStore.Create(ctx, callerUser) + + sess, _ := auth.NewSession(targetUser.ID, targetUser.Role, auth.DefaultSessionDuration, nil) + _ = sessionStore.Create(ctx, sess) + + // Non-admin caller trying to revoke another user's sessions. + req := httptest.NewRequest(stdhttp.MethodPost, "/api/v1/auth/users/"+targetUser.ID+"/revoke-sessions", nil) + req = req.WithContext(auth.ContextWithRole(auth.ContextWithUser(req.Context(), callerUser), auth.RoleViewer)) + rr := httptest.NewRecorder() + us.mux.ServeHTTP(rr, req) + + if rr.Code != stdhttp.StatusForbidden { + t.Fatalf("expected 403, got %d: %s", rr.Code, rr.Body.String()) + } + + // Verify session still exists. + sessions, _ := sessionStore.ListByUserID(ctx, targetUser.ID) + if len(sessions) != 1 { + t.Errorf("expected 1 session (not revoked), got %d", len(sessions)) + } +} + +func TestListByUserID(t *testing.T) { + store := auth.NewMemorySessionStore() + ctx := context.Background() + + // Create sessions for two users. + for i := 0; i < 3; i++ { + sess, _ := auth.NewSession("user-a", auth.RoleViewer, auth.DefaultSessionDuration, nil) + _ = store.Create(ctx, sess) + } + for i := 0; i < 2; i++ { + sess, _ := auth.NewSession("user-b", auth.RoleViewer, auth.DefaultSessionDuration, nil) + _ = store.Create(ctx, sess) + } + + sessionsA, err := store.ListByUserID(ctx, "user-a") + if err != nil { + t.Fatalf("ListByUserID: %v", err) + } + if len(sessionsA) != 3 { + t.Errorf("expected 3 sessions for user-a, got %d", len(sessionsA)) + } + + sessionsB, err := store.ListByUserID(ctx, "user-b") + if err != nil { + t.Fatalf("ListByUserID: %v", err) + } + if len(sessionsB) != 2 { + t.Errorf("expected 2 sessions for user-b, got %d", len(sessionsB)) + } + + // Non-existent user returns empty. + sessionsC, err := store.ListByUserID(ctx, "user-c") + if err != nil { + t.Fatalf("ListByUserID: %v", err) + } + if len(sessionsC) != 0 { + t.Errorf("expected 0 sessions for user-c, got %d", len(sessionsC)) + } +} + +func TestSessionLimit_Enforcement(t *testing.T) { + us, sessionStore, userStore := setupUserTestServer() + + ctx := context.Background() + hash, _ := auth.HashPassword("TestPass123!") + user := &auth.User{ + ID: "user-limit", + Username: "limituser", + Email: "limit@example.com", + Role: auth.RoleViewer, + PasswordHash: hash, + IsActive: true, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + if err := userStore.Create(ctx, user); err != nil { + t.Fatalf("create user: %v", err) + } + + // Pre-create 11 sessions (over the limit of 10). + for i := 0; i < 11; i++ { + sess, _ := auth.NewSession(user.ID, user.Role, auth.DefaultSessionDuration, nil) + // Stagger creation times so oldest can be identified. + sess.CreatedAt = time.Now().UTC().Add(time.Duration(i) * time.Second) + _ = sessionStore.Create(ctx, sess) + } + + // Login, which should trigger eviction. + body := `{"username":"limituser","password":"TestPass123!"}` + req := httptest.NewRequest(stdhttp.MethodPost, "/api/v1/auth/login", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + us.mux.ServeHTTP(rr, req) + + if rr.Code != stdhttp.StatusOK { + t.Fatalf("login: expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + + // After login: 11 existing + 1 new = 12, then evict 2 oldest => 10. + sessions, err := sessionStore.ListByUserID(ctx, user.ID) + if err != nil { + t.Fatalf("list sessions: %v", err) + } + if len(sessions) != 10 { + t.Errorf("expected 10 sessions after eviction, got %d", len(sessions)) + } +} diff --git a/internal/auth/session.go b/internal/auth/session.go index 4263d26..3e247b8 100644 --- a/internal/auth/session.go +++ b/internal/auth/session.go @@ -73,6 +73,9 @@ type SessionStore interface { // DeleteByUserID removes all sessions for a specific user. DeleteByUserID(ctx context.Context, userID string) error + // ListByUserID returns all valid (non-expired) sessions for a specific user. + ListByUserID(ctx context.Context, userID string) ([]*Session, error) + // Cleanup removes all expired sessions. // Returns the number of sessions removed. Cleanup(ctx context.Context) (int, error) @@ -200,6 +203,19 @@ func (s *MemorySessionStore) DeleteByUserID(_ context.Context, userID string) er return nil } +// ListByUserID returns all valid (non-expired) sessions for a specific user. +func (s *MemorySessionStore) ListByUserID(_ context.Context, userID string) ([]*Session, error) { + s.mu.RLock() + defer s.mu.RUnlock() + var result []*Session + for _, sess := range s.sessions { + if sess.UserID == userID && sess.IsValid() { + result = append(result, copySession(sess)) + } + } + return result, nil +} + // Cleanup removes all expired sessions. // Returns the number of sessions removed. func (s *MemorySessionStore) Cleanup(_ context.Context) (int, error) { diff --git a/internal/auth/session_postgres.go b/internal/auth/session_postgres.go index 57279e6..34c70dc 100644 --- a/internal/auth/session_postgres.go +++ b/internal/auth/session_postgres.go @@ -114,6 +114,27 @@ func (s *PostgresSessionStore) DeleteByUserID(ctx context.Context, userID string return err } +func (s *PostgresSessionStore) ListByUserID(ctx context.Context, userID string) ([]*Session, error) { + rows, err := s.pool.Query(ctx, + `SELECT id, user_id, role, created_at, expires_at FROM sessions WHERE user_id = $1 AND expires_at > NOW() ORDER BY created_at ASC`, + userID) + if err != nil { + return nil, err + } + defer rows.Close() + var sessions []*Session + for rows.Next() { + var sess Session + var role string + if err := rows.Scan(&sess.ID, &sess.UserID, &role, &sess.CreatedAt, &sess.ExpiresAt); err != nil { + return nil, err + } + sess.Role = Role(role) + sessions = append(sessions, &sess) + } + return sessions, rows.Err() +} + func (s *PostgresSessionStore) Cleanup(ctx context.Context) (int, error) { tag, err := s.pool.Exec(ctx, `DELETE FROM sessions WHERE expires_at < $1`, time.Now().UTC()) if err != nil { diff --git a/internal/auth/session_sqlite.go b/internal/auth/session_sqlite.go index 71e9375..667236f 100644 --- a/internal/auth/session_sqlite.go +++ b/internal/auth/session_sqlite.go @@ -130,6 +130,29 @@ func (s *SQLiteSessionStore) DeleteByUserID(ctx context.Context, userID string) return nil } +func (s *SQLiteSessionStore) ListByUserID(ctx context.Context, userID string) ([]*Session, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT id, user_id, role, created_at, expires_at FROM sessions WHERE user_id = ? AND expires_at > datetime('now') ORDER BY created_at ASC`, + userID) + if err != nil { + return nil, err + } + defer rows.Close() + var sessions []*Session + for rows.Next() { + var sess Session + var role, createdAt, expiresAt string + if err := rows.Scan(&sess.ID, &sess.UserID, &role, &createdAt, &expiresAt); err != nil { + return nil, err + } + sess.Role = Role(role) + sess.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) + sess.ExpiresAt, _ = time.Parse(time.RFC3339, expiresAt) + sessions = append(sessions, &sess) + } + return sessions, rows.Err() +} + func (s *SQLiteSessionStore) Cleanup(ctx context.Context) (int, error) { res, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE expires_at < ?`, time.Now().UTC().Format(time.RFC3339Nano)) From f79111683ee066d14f5804717e87c0a0350790c0 Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 12:48:53 -0600 Subject: [PATCH 08/10] feat: CSRF protection middleware for session-authenticated requests - Double-submit cookie pattern: csrf_token cookie + X-CSRF-Token header - API key requests exempt (no cookies = no CSRF risk) - Login and setup endpoints exempt (no session yet) - Frontend API client sends CSRF token on all state-changing requests Co-Authored-By: Claude Opus 4.6 --- cmd/cloudpam/main.go | 1 + internal/api/csrf.go | 73 +++++++++++++ internal/api/csrf_test.go | 223 ++++++++++++++++++++++++++++++++++++++ ui/src/api/client.ts | 25 ++++- 4 files changed, 320 insertions(+), 2 deletions(-) create mode 100644 internal/api/csrf.go create mode 100644 internal/api/csrf_test.go diff --git a/cmd/cloudpam/main.go b/cmd/cloudpam/main.go index 13f524d..e476486 100644 --- a/cmd/cloudpam/main.go +++ b/cmd/cloudpam/main.go @@ -232,6 +232,7 @@ func main() { observability.MetricsMiddleware(metrics), api.RequestIDMiddleware(), api.LoggingMiddleware(logger.Slog()), + api.CSRFMiddleware(), api.RateLimitMiddleware(rateCfg, logger.Slog()), ) server := &http.Server{ diff --git a/internal/api/csrf.go b/internal/api/csrf.go new file mode 100644 index 0000000..3a2c70f --- /dev/null +++ b/internal/api/csrf.go @@ -0,0 +1,73 @@ +package api + +import ( + "crypto/rand" + "encoding/hex" + "net/http" + "strings" +) + +const ( + csrfTokenLength = 32 + csrfHeaderName = "X-CSRF-Token" + csrfCookieName = "csrf_token" +) + +// CSRFMiddleware adds CSRF protection for session-authenticated state-changing requests. +// API key authenticated requests are exempt (no cookies = no CSRF risk). +// Login and setup endpoints are exempt (no session yet). +func CSRFMiddleware() Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip for safe methods + if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" { + // Set CSRF token cookie if not present + if _, err := r.Cookie(csrfCookieName); err != nil { + token := generateCSRFToken() + http.SetCookie(w, &http.Cookie{ + Name: csrfCookieName, + Value: token, + Path: "/", + HttpOnly: false, // JS needs to read it + Secure: r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https", + SameSite: http.SameSiteLaxMode, + }) + } + next.ServeHTTP(w, r) + return + } + + // For state-changing methods: skip CSRF check if using API key auth + if authHeader := r.Header.Get("Authorization"); authHeader != "" && strings.HasPrefix(authHeader, "Bearer cpam_") { + next.ServeHTTP(w, r) + return + } + + // Skip CSRF for login and setup endpoints (no session yet) + if r.URL.Path == "/api/v1/auth/login" || r.URL.Path == "/api/v1/auth/setup" { + next.ServeHTTP(w, r) + return + } + + // Validate CSRF token from header matches cookie + cookie, err := r.Cookie(csrfCookieName) + if err != nil { + writeJSON(w, http.StatusForbidden, apiError{Error: "CSRF token missing", Detail: "csrf_token cookie required"}) + return + } + headerToken := r.Header.Get(csrfHeaderName) + if headerToken == "" || headerToken != cookie.Value { + writeJSON(w, http.StatusForbidden, apiError{Error: "CSRF token invalid", Detail: "X-CSRF-Token header must match csrf_token cookie"}) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +func generateCSRFToken() string { + b := make([]byte, csrfTokenLength) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} diff --git a/internal/api/csrf_test.go b/internal/api/csrf_test.go new file mode 100644 index 0000000..9e79333 --- /dev/null +++ b/internal/api/csrf_test.go @@ -0,0 +1,223 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestCSRFMiddleware_GETSetsTokenCookie(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/pools", nil) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + + // Check that csrf_token cookie was set + var found bool + for _, c := range rr.Result().Cookies() { + if c.Name == csrfCookieName { + found = true + if c.Value == "" { + t.Fatal("csrf_token cookie value is empty") + } + if c.HttpOnly { + t.Fatal("csrf_token cookie should not be HttpOnly (JS needs to read it)") + } + break + } + } + if !found { + t.Fatal("expected csrf_token cookie to be set on GET request") + } +} + +func TestCSRFMiddleware_GETDoesNotSetCookieIfAlreadyPresent(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/pools", nil) + req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: "existing-token"}) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + + // Should not set a new cookie since one already exists + for _, c := range rr.Result().Cookies() { + if c.Name == csrfCookieName { + t.Fatal("should not set csrf_token cookie when one already exists") + } + } +} + +func TestCSRFMiddleware_POSTWithoutCSRFTokenReturns403(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/pools", nil) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", rr.Code) + } +} + +func TestCSRFMiddleware_POSTWithValidCSRFTokenSucceeds(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + token := "test-csrf-token-value" + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/pools", nil) + req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token}) + req.Header.Set(csrfHeaderName, token) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } +} + +func TestCSRFMiddleware_POSTWithMismatchedTokenReturns403(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/pools", nil) + req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: "cookie-token"}) + req.Header.Set(csrfHeaderName, "different-header-token") + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", rr.Code) + } +} + +func TestCSRFMiddleware_POSTWithAPIKeyBypassesCSRF(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/pools", nil) + req.Header.Set("Authorization", "Bearer cpam_testkey123abc") + // No CSRF token provided - should still succeed + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 for API key auth (CSRF bypass), got %d", rr.Code) + } +} + +func TestCSRFMiddleware_POSTToLoginBypassesCSRF(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 for login bypass, got %d", rr.Code) + } +} + +func TestCSRFMiddleware_POSTToSetupBypassesCSRF(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/setup", nil) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 for setup bypass, got %d", rr.Code) + } +} + +func TestCSRFMiddleware_DELETERequiresCSRFToken(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Without token - should fail + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/api/v1/pools/1", nil) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", rr.Code) + } + + // With valid token - should succeed + token := "delete-csrf-token" + rr2 := httptest.NewRecorder() + req2 := httptest.NewRequest(http.MethodDelete, "/api/v1/pools/1", nil) + req2.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token}) + req2.Header.Set(csrfHeaderName, token) + handler.ServeHTTP(rr2, req2) + + if rr2.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr2.Code) + } +} + +func TestCSRFMiddleware_PATCHRequiresCSRFToken(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Without token - should fail + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPatch, "/api/v1/pools/1", nil) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", rr.Code) + } +} + +func TestCSRFMiddleware_POSTWithCookieButNoHeaderReturns403(t *testing.T) { + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/pools", nil) + req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: "some-token"}) + // No X-CSRF-Token header + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", rr.Code) + } +} + +func TestGenerateCSRFToken(t *testing.T) { + token := generateCSRFToken() + if len(token) != csrfTokenLength*2 { // hex encoding doubles length + t.Fatalf("expected token length %d, got %d", csrfTokenLength*2, len(token)) + } + + // Ensure tokens are unique + token2 := generateCSRFToken() + if token == token2 { + t.Fatal("expected unique tokens, got identical values") + } +} diff --git a/ui/src/api/client.ts b/ui/src/api/client.ts index 58584c4..c15dbf5 100644 --- a/ui/src/api/client.ts +++ b/ui/src/api/client.ts @@ -10,10 +10,25 @@ export class ApiRequestError extends Error { } } +function getCSRFToken(): string | null { + const match = document.cookie.match(/(?:^|;\s*)csrf_token=([^;]+)/) + return match ? match[1] : null +} + async function request(path: string, options?: RequestInit): Promise { + const headers: Record = { 'Content-Type': 'application/json' } + + // Add CSRF token for state-changing requests + if (options?.method && options.method !== 'GET' && options.method !== 'HEAD') { + const csrfToken = getCSRFToken() + if (csrfToken) { + headers['X-CSRF-Token'] = csrfToken + } + } + const res = await fetch(path, { credentials: 'same-origin', - headers: { 'Content-Type': 'application/json' }, + headers, ...options, }) @@ -67,10 +82,16 @@ export interface SSECallbacks { } export async function streamPost(path: string, data: unknown, callbacks: SSECallbacks): Promise { + const streamHeaders: Record = { 'Content-Type': 'application/json' } + const csrfToken = getCSRFToken() + if (csrfToken) { + streamHeaders['X-CSRF-Token'] = csrfToken + } + const res = await fetch(path, { method: 'POST', credentials: 'same-origin', - headers: { 'Content-Type': 'application/json' }, + headers: streamHeaders, body: JSON.stringify(data), }) From 26ebfd530bfdd2ec4ff3e3ecf059512d89ea1934 Mon Sep 17 00:00:00 2001 From: BadgerOps Date: Wed, 18 Feb 2026 12:49:04 -0600 Subject: [PATCH 09/10] feat: security settings UI page under Config > Security - SecuritySettingsPage with session, password, login, network sections - useSecuritySettings hook for GET/PATCH /api/v1/settings/security - Sidebar link with Shield icon (admin only) - Coming soon placeholders for Roles & Permissions and SSO/OIDC Co-Authored-By: Claude Opus 4.6 --- ui/src/App.tsx | 2 + ui/src/components/Sidebar.tsx | 7 + ui/src/hooks/useSettings.ts | 41 ++++ ui/src/pages/SecuritySettingsPage.tsx | 263 ++++++++++++++++++++++++++ web/dist/index.html | 6 +- 5 files changed, 316 insertions(+), 3 deletions(-) create mode 100644 ui/src/hooks/useSettings.ts create mode 100644 ui/src/pages/SecuritySettingsPage.tsx diff --git a/ui/src/App.tsx b/ui/src/App.tsx index ef0d473..3fcd0c4 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -18,6 +18,7 @@ import RecommendationsPage from './pages/RecommendationsPage' import AIPlannerPage from './pages/AIPlannerPage' import ProfilePage from './pages/ProfilePage' import LogDestinationsPage from './pages/LogDestinationsPage' +import SecuritySettingsPage from './pages/SecuritySettingsPage' export default function App() { const toastState = useToastState() @@ -45,6 +46,7 @@ export default function App() { } /> } /> } /> + } /> diff --git a/ui/src/components/Sidebar.tsx b/ui/src/components/Sidebar.tsx index e9c767c..a04de4f 100644 --- a/ui/src/components/Sidebar.tsx +++ b/ui/src/components/Sidebar.tsx @@ -14,6 +14,7 @@ import { Lightbulb, Bot, Radio, + Shield, } from 'lucide-react' import { useAuth } from '../hooks/useAuth' @@ -116,6 +117,12 @@ export default function Sidebar({ onImportExport }: SidebarProps) { Log Destinations + {role === 'admin' && ( + + + Security + + )} diff --git a/ui/src/hooks/useSettings.ts b/ui/src/hooks/useSettings.ts new file mode 100644 index 0000000..9fce167 --- /dev/null +++ b/ui/src/hooks/useSettings.ts @@ -0,0 +1,41 @@ +import { useState, useEffect, useCallback } from 'react' +import { get, patch } from '../api/client' + +export interface SecuritySettings { + session_duration_hours: number + max_sessions_per_user: number + password_min_length: number + password_max_length: number + login_rate_limit_per_minute: number + account_lockout_attempts: number + trusted_proxies: string[] +} + +export function useSecuritySettings() { + const [settings, setSettings] = useState(null) + const [loading, setLoading] = useState(true) + const [error, setError] = useState(null) + + const fetchSettings = useCallback(async () => { + try { + setLoading(true) + const data = await get('/api/v1/settings/security') + setSettings(data) + setError(null) + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to load security settings') + } finally { + setLoading(false) + } + }, []) + + const updateSettings = useCallback(async (updated: SecuritySettings) => { + const data = await patch('/api/v1/settings/security', updated) + setSettings(data) + return data + }, []) + + useEffect(() => { fetchSettings() }, [fetchSettings]) + + return { settings, loading, error, updateSettings, refetch: fetchSettings } +} diff --git a/ui/src/pages/SecuritySettingsPage.tsx b/ui/src/pages/SecuritySettingsPage.tsx new file mode 100644 index 0000000..59e8c37 --- /dev/null +++ b/ui/src/pages/SecuritySettingsPage.tsx @@ -0,0 +1,263 @@ +import { useState, useEffect } from 'react' +import { Shield, Lock, Key, Globe, AlertCircle, Loader2, Users, Fingerprint } from 'lucide-react' +import { useSecuritySettings } from '../hooks/useSettings' +import type { SecuritySettings } from '../hooks/useSettings' +import { useToast } from '../hooks/useToast' + +export default function SecuritySettingsPage() { + const { settings, loading, error, updateSettings } = useSecuritySettings() + const { showToast } = useToast() + const [form, setForm] = useState(null) + const [saving, setSaving] = useState(false) + const [trustedProxiesText, setTrustedProxiesText] = useState('') + + useEffect(() => { + if (settings) { + setForm(settings) + setTrustedProxiesText((settings.trusted_proxies ?? []).join('\n')) + } + }, [settings]) + + async function handleSave() { + if (!form) return + setSaving(true) + try { + const proxies = trustedProxiesText + .split('\n') + .map(l => l.trim()) + .filter(l => l.length > 0) + await updateSettings({ ...form, trusted_proxies: proxies }) + showToast('Security settings saved', 'success') + } catch (err) { + showToast(err instanceof Error ? err.message : 'Failed to save settings', 'error') + } finally { + setSaving(false) + } + } + + function updateField(key: K, value: SecuritySettings[K]) { + setForm(prev => prev ? { ...prev, [key]: value } : prev) + } + + if (loading) { + return ( +
+ + Loading security settings... +
+ ) + } + + if (error) { + return ( +
+ + {error} +
+ ) + } + + if (!form) return null + + return ( +
+
+
+

+ + Security Settings +

+

+ Configure authentication, session, and password policies +

+
+ +
+ +
+ {/* Session Management */} +
+

+ + Session Management +

+
+
+ + updateField('session_duration_hours', parseInt(e.target.value) || 1)} + className="w-full px-3 py-2 border rounded-lg text-sm dark:bg-gray-700 dark:border-gray-600 dark:text-white focus:ring-2 focus:ring-blue-500" + /> +

+ How long a session cookie remains valid before requiring re-login +

+
+
+ + updateField('max_sessions_per_user', parseInt(e.target.value) || 1)} + className="w-full px-3 py-2 border rounded-lg text-sm dark:bg-gray-700 dark:border-gray-600 dark:text-white focus:ring-2 focus:ring-blue-500" + /> +

+ Maximum concurrent sessions allowed per user account +

+
+
+
+ + {/* Password Policy */} +
+

+ + Password Policy +

+
+
+ + updateField('password_min_length', parseInt(e.target.value) || 8)} + className="w-full px-3 py-2 border rounded-lg text-sm dark:bg-gray-700 dark:border-gray-600 dark:text-white focus:ring-2 focus:ring-blue-500" + /> +

+ Minimum number of characters required for passwords +

+
+
+ + updateField('password_max_length', Math.min(72, parseInt(e.target.value) || 72))} + className="w-full px-3 py-2 border rounded-lg text-sm dark:bg-gray-700 dark:border-gray-600 dark:text-white focus:ring-2 focus:ring-blue-500" + /> +

+ Maximum password length (capped at 72 for bcrypt compatibility) +

+
+
+
+ + {/* Login Protection */} +
+

+ + Login Protection +

+
+
+ + updateField('login_rate_limit_per_minute', parseInt(e.target.value) || 5)} + className="w-full px-3 py-2 border rounded-lg text-sm dark:bg-gray-700 dark:border-gray-600 dark:text-white focus:ring-2 focus:ring-blue-500" + /> +

+ Maximum login attempts per IP address per minute +

+
+
+ + updateField('account_lockout_attempts', parseInt(e.target.value) || 0)} + className="w-full px-3 py-2 border rounded-lg text-sm dark:bg-gray-700 dark:border-gray-600 dark:text-white focus:ring-2 focus:ring-blue-500" + /> +

+ Lock account after this many failed attempts (0 = disabled) +

+
+
+
+ + {/* Network */} +
+

+ + Network +

+
+ +