diff --git a/CLAUDE.md b/CLAUDE.md index 48a3f37..426fac7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -43,7 +43,7 @@ go test -v ./internal/output/... -run TestHumanFormatter ### Core Packages -- `internal/auth/` - Authentication provider supporting two modes. JWT (priority): client credentials exchange at `/api/v1/authenticate`, auto-refresh 5min before expiry, tenant ID extracted from `customer_id` JWT claim. Basic (fallback): static token + explicit tenant ID. Implements `AuthHeaderProvider` interface used by the API client. +- `internal/auth/` - Authentication provider supporting two modes. JWT (priority): client credentials exchange at `/api/v1/auth/token`, auto-refresh 5min before expiry, tenant ID extracted from `customer_id` JWT claim. Basic (fallback): static token + explicit tenant ID. Implements `AuthHeaderProvider` interface used by the API client. - `internal/api/` - API client for Armis Cloud. Two HTTP clients: one for general calls (60s timeout), one for uploads (streaming, no timeout, no retry). Functional options pattern (`WithHTTPClient()`, `WithUploadHTTPClient()`, `WithAllowLocalURLs()`). Upload uses `io.Pipe` streaming to avoid OOM on large files. Enforces HTTPS, validates presigned S3 URLs against SSRF. - `internal/model/` - Data structures: `Finding` (23 fields), `ScanResult`, `Summary`, `Fix`, `FindingValidation` (with taint/reachability analysis), API response types (`NormalizedFinding`, pagination). - `internal/output/` - Output formatters (human, json, sarif, junit) implementing the `Formatter` interface. `styles.go` defines ~50 lipgloss styles using Tailwind CSS color palette. `icons.go` defines Unicode constants (severity dots, box-drawing chars). `SyncColors()` switches between full-color and plain styles based on `cli.ColorsEnabled()`. @@ -83,9 +83,10 @@ go test -v ./internal/output/... -run TestHumanFormatter - `ARMIS_CLIENT_ID` - Client ID for JWT authentication (recommended) - `ARMIS_CLIENT_SECRET` - Client secret for JWT authentication -- `ARMIS_AUTH_ENDPOINT` - JWT authentication service endpoint URL - `ARMIS_API_TOKEN` - API token for Basic authentication (fallback) - `ARMIS_TENANT_ID` - Tenant identifier (required only with Basic auth; JWT extracts it from token) +- `ARMIS_API_URL` - Override base URL for Armis API (advanced; defaults based on --dev flag) +- `ARMIS_REGION` - Override Armis cloud region (equivalent to `--region`; used for region-aware authentication) - `ARMIS_FORMAT` - Default output format - `ARMIS_PAGE_LIMIT` - Results pagination size - `ARMIS_THEME` - Terminal background theme: auto, dark, light (default: auto) diff --git a/docs/CI-INTEGRATION.md b/docs/CI-INTEGRATION.md index fbbc2cf..0c4272c 100644 --- a/docs/CI-INTEGRATION.md +++ b/docs/CI-INTEGRATION.md @@ -697,7 +697,7 @@ Configure `ARMIS_API_TOKEN` and `ARMIS_TENANT_ID` as [secured repository variabl #### "authentication required" - No valid authentication credentials were provided -- Set `ARMIS_API_TOKEN` and `ARMIS_TENANT_ID` environment variables or secrets +- Set `ARMIS_CLIENT_ID` and `ARMIS_CLIENT_SECRET` for JWT auth (recommended), or `ARMIS_API_TOKEN` and `ARMIS_TENANT_ID` for legacy auth #### "tenant ID required" diff --git a/docs/FEATURES.md b/docs/FEATURES.md index 49e3dd5..7ddd238 100644 --- a/docs/FEATURES.md +++ b/docs/FEATURES.md @@ -294,9 +294,10 @@ armis-cli scan repo . \ |----------|-------------| | `ARMIS_CLIENT_ID` | Client ID for JWT authentication | | `ARMIS_CLIENT_SECRET` | Client secret for JWT authentication | -| `ARMIS_AUTH_ENDPOINT` | Authentication service endpoint URL | | `ARMIS_API_TOKEN` | API token for Basic authentication | | `ARMIS_TENANT_ID` | Tenant identifier (required for Basic auth only) | +| `ARMIS_API_URL` | Override base URL for Armis API and authentication (advanced) | +| `ARMIS_REGION` | Authentication region override (advanced; corresponds to `--region` flag) | **General:** diff --git a/internal/api/client.go b/internal/api/client.go index a08c289..4dc47d8 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -216,12 +216,9 @@ func (c *Client) IsDebug() bool { // request URL uses HTTPS (or localhost for testing). This prevents credential // exposure over insecure channels. // -// For JWT auth: sends raw JWT token (no "Bearer" prefix) +// For JWT auth: sends "Bearer " per RFC 6750 // For Basic auth: sends "Basic " per RFC 7617 // -// NOTE: The backend expects raw JWT tokens without the "Bearer" prefix. -// This is unconventional but matches the backend API contract. -// // SECURITY NOTE: The localhost/127.0.0.1 exception is intentional for local // development and testing environments where HTTPS certificates are not available. // Production deployments must always use HTTPS. diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 22d0f36..16ceccd 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -7,7 +7,9 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" + "net/http" "strings" "sync" "time" @@ -18,7 +20,8 @@ type AuthConfig struct { // JWT auth credentials ClientID string ClientSecret string //nolint:gosec // G117: This is a config field name, not a secret value - AuthEndpoint string // Full URL to the authentication service + BaseURL string // Moose API base URL (dev or prod) + Region string // Optional region override - bypasses auto-discovery if set // Legacy Basic auth Token string @@ -33,21 +36,24 @@ type JWTCredentials struct { Token string TenantID string // Extracted from customer_id claim ExpiresAt time.Time + Region string // Deployment region (e.g., "us1", "eu1", "au1") } // AuthProvider manages authentication tokens with automatic refresh. // It supports both JWT authentication and legacy Basic authentication. // For JWT auth, tokens are automatically refreshed when within 5 minutes of expiry. type AuthProvider struct { - config AuthConfig - credentials *JWTCredentials - authClient *AuthClient - mu sync.RWMutex - isLegacy bool // true if using Basic auth (--token) + config AuthConfig + credentials *JWTCredentials + authClient *AuthClient + mu sync.RWMutex + isLegacy bool // true if using Basic auth (--token) + cachedRegion string // memoized region from disk cache (loaded once) + regionLoaded bool // true if cachedRegion has been loaded from disk } // NewAuthProvider creates an AuthProvider from configuration. -// If ClientID and ClientSecret are set, uses JWT auth with the specified endpoint. +// If ClientID and ClientSecret are set, uses JWT auth with the specified base URL. // Otherwise falls back to legacy Basic auth with Token. func NewAuthProvider(config AuthConfig) (*AuthProvider, error) { p := &AuthProvider{ @@ -64,12 +70,12 @@ func NewAuthProvider(config AuthConfig) (*AuthProvider, error) { // Determine auth mode: JWT credentials take priority if config.ClientID != "" && config.ClientSecret != "" { - // JWT auth + // JWT auth via moose p.isLegacy = false - if config.AuthEndpoint == "" { - return nil, fmt.Errorf("--auth-endpoint is required when using client credentials") + if config.BaseURL == "" { + return nil, fmt.Errorf("base URL is required for JWT authentication") } - authClient, err := NewAuthClient(config.AuthEndpoint, config.Debug) + authClient, err := NewAuthClient(config.BaseURL, config.Debug) if err != nil { return nil, fmt.Errorf("failed to create auth client: %w", err) } @@ -86,14 +92,14 @@ func NewAuthProvider(config AuthConfig) (*AuthProvider, error) { return nil, fmt.Errorf("tenant ID required: use --tenant-id flag or ARMIS_TENANT_ID environment variable") } } else { - return nil, fmt.Errorf("authentication required: use --token flag or ARMIS_API_TOKEN environment variable") + return nil, fmt.Errorf("authentication required: set ARMIS_CLIENT_ID and ARMIS_CLIENT_SECRET for JWT auth, or ARMIS_API_TOKEN for legacy auth") } return p, nil } // GetAuthorizationHeader returns the Authorization header value. -// For JWT auth: the raw JWT token (no "Bearer" prefix - backend expects raw JWT) +// For JWT auth: "Bearer " per RFC 6750 // For Basic auth: "Basic " per RFC 7617 func (p *AuthProvider) GetAuthorizationHeader(ctx context.Context) (string, error) { if p.isLegacy { @@ -108,8 +114,8 @@ func (p *AuthProvider) GetAuthorizationHeader(ctx context.Context) (string, erro p.mu.RLock() defer p.mu.RUnlock() - // Raw JWT token (no Bearer prefix) - backend expects raw JWT per API contract - return p.credentials.Token, nil + // Bearer token per RFC 6750 + return "Bearer " + p.credentials.Token, nil } // GetTenantID returns the tenant ID for API requests. @@ -129,6 +135,23 @@ func (p *AuthProvider) GetTenantID(ctx context.Context) (string, error) { return p.credentials.TenantID, nil } +// GetRegion returns the deployment region from the JWT token. +// For JWT auth: extracted from region claim (may be empty for older tokens) +// For Basic auth: returns empty string (no region available) +func (p *AuthProvider) GetRegion(ctx context.Context) (string, error) { + if p.isLegacy { + return "", nil // Legacy auth doesn't have region + } + + if err := p.refreshIfNeeded(ctx); err != nil { + return "", fmt.Errorf("failed to refresh token: %w", err) + } + + p.mu.RLock() + defer p.mu.RUnlock() + return p.credentials.Region, nil +} + // IsLegacy returns true if using legacy Basic auth. func (p *AuthProvider) IsLegacy() bool { return p.isLegacy @@ -160,6 +183,16 @@ func (p *AuthProvider) GetRawToken(ctx context.Context) (string, error) { // exchangeCredentials exchanges client credentials for a JWT token. // Uses double-checked locking to prevent thundering herd of concurrent refreshes. +// Leverages region caching to avoid auto-discovery overhead on subsequent requests. +// +// Region selection priority: +// 1. --region flag (config.Region) - explicit override, bypasses cache and discovery +// 2. Cached region - from previous successful auth for this client_id +// 3. Auto-discovery - server tries regions until one succeeds +// +// Retry behavior: If auth fails with a cached region hint (not explicit --region), +// the cache is cleared and auth is retried without the hint. This handles stale +// cache gracefully without requiring user to re-run the command. func (p *AuthProvider) exchangeCredentials(ctx context.Context) error { p.mu.Lock() defer p.mu.Unlock() @@ -169,21 +202,62 @@ func (p *AuthProvider) exchangeCredentials(ctx context.Context) error { return nil } - token, err := p.authClient.Authenticate(ctx, p.config.ClientID, p.config.ClientSecret) + // Load cached region once per process (memoize to avoid repeated disk I/O) + if !p.regionLoaded { + if region, ok := loadCachedRegion(p.config.ClientID); ok { + p.cachedRegion = region + } + p.regionLoaded = true + } + + // Determine region hint - explicit flag takes priority over cache + var regionHint *string + var usingCachedHint bool + if p.config.Region != "" { + // Explicit --region flag - don't retry on failure (user error) + regionHint = &p.config.Region + } else if p.cachedRegion != "" { + // Cached region - will retry without hint on failure + regionHint = &p.cachedRegion + usingCachedHint = true + } + + result, err := p.authClient.Authenticate(ctx, p.config.ClientID, p.config.ClientSecret, regionHint) if err != nil { - return err + // If auth failed with a cached region hint, retry only for region-specific rejections. + // Skip retry for: transport errors (not *AuthError), 401 (bad credentials). + // This avoids double requests on network failures and prevents wiping correct cache entries. + var authErr *AuthError + if usingCachedHint && errors.As(err, &authErr) && authErr.StatusCode != http.StatusUnauthorized { + clearCachedRegion() + p.cachedRegion = "" + // Retry without region hint - let server auto-discover + result, err = p.authClient.Authenticate(ctx, p.config.ClientID, p.config.ClientSecret, nil) + if err != nil { + return err + } + } else { + return err + } + } + + // Cache the discovered region for future requests (skip if unchanged) + if result.Region != "" && result.Region != p.cachedRegion { + saveCachedRegion(p.config.ClientID, result.Region) + p.cachedRegion = result.Region } // Parse JWT to extract claims - claims, err := parseJWTClaims(token) + claims, err := parseJWTClaims(result.Token) if err != nil { return fmt.Errorf("failed to parse JWT: %w", err) } p.credentials = &JWTCredentials{ - Token: token, + Token: result.Token, TenantID: claims.CustomerID, ExpiresAt: claims.ExpiresAt, + Region: claims.Region, } return nil @@ -207,6 +281,7 @@ func (p *AuthProvider) refreshIfNeeded(ctx context.Context) error { type jwtClaims struct { CustomerID string // maps to tenant_id ExpiresAt time.Time + Region string // deployment region (optional) } // parseJWTClaims extracts claims from a JWT without signature verification. @@ -233,7 +308,8 @@ func parseJWTClaims(token string) (*jwtClaims, error) { var data struct { CustomerID string `json:"customer_id"` - Exp float64 `json:"exp"` // float64 to handle servers that return fractional timestamps + Exp float64 `json:"exp"` // float64 to handle servers that return fractional timestamps + Region string `json:"region"` // optional deployment region } if err := json.Unmarshal(payload, &data); err != nil { return nil, fmt.Errorf("failed to parse JWT payload: %w", err) @@ -252,5 +328,6 @@ func parseJWTClaims(token string) (*jwtClaims, error) { return &jwtClaims{ CustomerID: data.CustomerID, ExpiresAt: time.Unix(expSec, 0), + Region: data.Region, // may be empty for backward compatibility }, nil } diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index d8c419f..38e417d 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -49,6 +49,23 @@ func createMockJWTWithFloatExp(customerID string, exp float64) string { return header + "." + payload + "." + signature } +// createMockJWTWithRegion creates a mock JWT token with region claim. +func createMockJWTWithRegion(customerID string, exp int64, region string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + + claims := map[string]interface{}{ + "customer_id": customerID, + "exp": exp, + "region": region, + } + claimsJSON, _ := json.Marshal(claims) + payload := base64.RawURLEncoding.EncodeToString(claimsJSON) + + signature := base64.RawURLEncoding.EncodeToString([]byte("test-signature")) + + return header + "." + payload + "." + signature +} + func TestNewAuthProvider_LegacyAuth(t *testing.T) { t.Run("succeeds with token and tenant ID", func(t *testing.T) { config := AuthConfig{ @@ -171,7 +188,7 @@ func TestNewAuthProvider_JWTAuth(t *testing.T) { mockJWT := createMockJWT("tenant-from-jwt", time.Now().Add(1*time.Hour).Unix()) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/api/v1/authenticate" { + if r.URL.Path != "/api/v1/auth/token" { t.Errorf("Unexpected path: %s", r.URL.Path) } if r.Method != "POST" { @@ -186,7 +203,7 @@ func TestNewAuthProvider_JWTAuth(t *testing.T) { config := AuthConfig{ ClientID: "test-client", ClientSecret: "test-secret", - AuthEndpoint: server.URL, + BaseURL: server.URL, } p, err := NewAuthProvider(config) @@ -202,9 +219,10 @@ func TestNewAuthProvider_JWTAuth(t *testing.T) { if err != nil { t.Fatalf("GetAuthorizationHeader failed: %v", err) } - // Raw JWT token (no Bearer prefix) - backend expects raw JWT per API contract - if header != mockJWT { - t.Errorf("Unexpected auth header: got %q, want %q", header, mockJWT) + // Bearer token per RFC 6750 + expectedHeader := "Bearer " + mockJWT + if header != expectedHeader { + t.Errorf("Unexpected auth header: got %q, want %q", header, expectedHeader) } tid, err := p.GetTenantID(context.Background()) @@ -234,7 +252,7 @@ func TestNewAuthProvider_JWTAuth(t *testing.T) { config := AuthConfig{ ClientID: "bad-client", ClientSecret: "bad-secret", - AuthEndpoint: server.URL, + BaseURL: server.URL, } _, err := NewAuthProvider(config) @@ -246,19 +264,19 @@ func TestNewAuthProvider_JWTAuth(t *testing.T) { } }) - t.Run("fails without auth endpoint", func(t *testing.T) { + t.Run("fails without base URL", func(t *testing.T) { config := AuthConfig{ ClientID: "test-client", ClientSecret: "test-secret", - // No AuthEndpoint + // No BaseURL } _, err := NewAuthProvider(config) if err == nil { - t.Error("Expected error for missing auth endpoint") + t.Error("Expected error for missing base URL") } - if !strings.Contains(err.Error(), "--auth-endpoint is required") { - t.Errorf("Expected auth endpoint required error, got: %v", err) + if !strings.Contains(err.Error(), "base URL is required") { + t.Errorf("Expected base URL required error, got: %v", err) } }) @@ -275,7 +293,7 @@ func TestNewAuthProvider_JWTAuth(t *testing.T) { config := AuthConfig{ ClientID: "test-client", ClientSecret: "test-secret", - AuthEndpoint: server.URL, + BaseURL: server.URL, Token: "basic-token", TenantID: "basic-tenant", } @@ -316,7 +334,7 @@ func TestAuthProvider_RefreshIfNeeded(t *testing.T) { config := AuthConfig{ ClientID: "test-client", ClientSecret: "test-secret", - AuthEndpoint: server.URL, + BaseURL: server.URL, } p, err := NewAuthProvider(config) @@ -353,7 +371,7 @@ func TestAuthProvider_NoRefreshWhenValid(t *testing.T) { config := AuthConfig{ ClientID: "test-client", ClientSecret: "test-secret", - AuthEndpoint: server.URL, + BaseURL: server.URL, } p, err := NewAuthProvider(config) @@ -390,7 +408,7 @@ func TestAuthProvider_ThreadSafe(t *testing.T) { config := AuthConfig{ ClientID: "test-client", ClientSecret: "test-secret", - AuthEndpoint: server.URL, + BaseURL: server.URL, } p, err := NewAuthProvider(config) @@ -515,6 +533,113 @@ func TestParseJWTClaims(t *testing.T) { t.Error("Expected error for invalid base64") } }) + + t.Run("valid JWT with region", func(t *testing.T) { + token := createMockJWTWithRegion("cust-789", 1700000000, "us1") + + result, err := parseJWTClaims(token) + if err != nil { + t.Fatalf("parseJWTClaims failed: %v", err) + } + + if result.CustomerID != "cust-789" { + t.Errorf("Expected customer_id 'cust-789', got %q", result.CustomerID) + } + if result.Region != "us1" { + t.Errorf("Expected region 'us1', got %q", result.Region) + } + }) + + t.Run("valid JWT without region (backward compatibility)", func(t *testing.T) { + // Tokens without region should still parse successfully + token := createMockJWT("cust-legacy", 1700000000) + + result, err := parseJWTClaims(token) + if err != nil { + t.Fatalf("parseJWTClaims failed: %v", err) + } + + if result.CustomerID != "cust-legacy" { + t.Errorf("Expected customer_id 'cust-legacy', got %q", result.CustomerID) + } + if result.Region != "" { + t.Errorf("Expected empty region for legacy token, got %q", result.Region) + } + }) +} + +func TestGetRegion(t *testing.T) { + t.Run("returns empty string for legacy Basic auth", func(t *testing.T) { + p, err := NewAuthProvider(AuthConfig{ + Token: "test-token", + TenantID: "tenant-123", + }) + if err != nil { + t.Fatalf("NewAuthProvider failed: %v", err) + } + + region, err := p.GetRegion(context.Background()) + if err != nil { + t.Fatalf("GetRegion failed: %v", err) + } + if region != "" { + t.Errorf("Expected empty region for Basic auth, got %q", region) + } + }) + + t.Run("returns region for JWT auth", func(t *testing.T) { + mockJWT := createMockJWTWithRegion("tenant-123", time.Now().Add(1*time.Hour).Unix(), "eu1") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"token": mockJWT}) + })) + defer server.Close() + + p, err := NewAuthProvider(AuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + BaseURL: server.URL, + }) + if err != nil { + t.Fatalf("NewAuthProvider failed: %v", err) + } + + region, err := p.GetRegion(context.Background()) + if err != nil { + t.Fatalf("GetRegion failed: %v", err) + } + if region != "eu1" { + t.Errorf("Expected region 'eu1', got %q", region) + } + }) + + t.Run("returns empty region for JWT without region claim", func(t *testing.T) { + mockJWT := createMockJWT("tenant-123", time.Now().Add(1*time.Hour).Unix()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"token": mockJWT}) + })) + defer server.Close() + + p, err := NewAuthProvider(AuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + BaseURL: server.URL, + }) + if err != nil { + t.Fatalf("NewAuthProvider failed: %v", err) + } + + region, err := p.GetRegion(context.Background()) + if err != nil { + t.Fatalf("GetRegion failed: %v", err) + } + if region != "" { + t.Errorf("Expected empty region for token without region, got %q", region) + } + }) } func TestIsLegacy(t *testing.T) { @@ -574,13 +699,13 @@ func TestNewAuthClient(t *testing.T) { } }) - t.Run("fails with empty endpoint", func(t *testing.T) { + t.Run("fails with empty base URL", func(t *testing.T) { _, err := NewAuthClient("", false) if err == nil { - t.Error("Expected error for empty endpoint") + t.Error("Expected error for empty base URL") } - if !strings.Contains(err.Error(), "auth endpoint is required") { - t.Errorf("Expected endpoint required error, got: %v", err) + if !strings.Contains(err.Error(), "API base URL is required") { + t.Errorf("Expected base URL required error, got: %v", err) } }) @@ -589,7 +714,7 @@ func TestNewAuthClient(t *testing.T) { if err == nil { t.Error("Expected error for invalid URL") } - if !strings.Contains(err.Error(), "invalid endpoint URL") { + if !strings.Contains(err.Error(), "invalid base URL") { t.Errorf("Expected invalid URL error, got: %v", err) } }) @@ -615,7 +740,7 @@ func TestAuthProvider_DoubleCheckedLocking(t *testing.T) { config := AuthConfig{ ClientID: "test-client", ClientSecret: "test-secret", - AuthEndpoint: server.URL, + BaseURL: server.URL, } p, err := NewAuthProvider(config) @@ -677,7 +802,7 @@ func TestAuthProvider_ContextCancellation(t *testing.T) { config: AuthConfig{ ClientID: "test-client", ClientSecret: "test-secret", - AuthEndpoint: slowServer.URL, + BaseURL: slowServer.URL, }, authClient: authClient, credentials: &JWTCredentials{ @@ -702,3 +827,89 @@ func TestAuthProvider_ContextCancellation(t *testing.T) { t.Logf("Got error: %v (context cancellation may manifest differently)", err) } } + +func TestAuthProvider_CachedRegionRetryOnFailure(t *testing.T) { + // Test that when auth fails with a cached region hint, + // the provider clears the cache and retries without the hint. + // This handles stale cache (region changed) without requiring user to re-run. + + // Use a temporary cache directory to avoid affecting real cache + origCache := defaultCache + tempDir := t.TempDir() + defaultCache = &RegionCache{cacheDir: tempDir} + t.Cleanup(func() { + defaultCache = origCache + }) + + // Track requests to verify retry behavior + var requestCount int32 + staleRegion := "stale-region" + correctRegion := "us1" + + // Pre-populate cache with a stale region + saveCachedRegion("test-client", staleRegion) + + mockJWT := createMockJWTWithRegion("tenant-123", time.Now().Add(1*time.Hour).Unix(), correctRegion) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&requestCount, 1) + + // Parse request body to check for region hint + var reqBody struct { + Region *string `json:"region"` + } + _ = json.NewDecoder(r.Body).Decode(&reqBody) + + // First request with stale region hint should fail with 403 (region rejected). + // Note: 401 means invalid credentials and should NOT trigger retry. + if reqBody.Region != nil && *reqBody.Region == staleRegion { + w.WriteHeader(http.StatusForbidden) + return + } + + // Second request without region hint (or with correct region) should succeed + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "token": mockJWT, + "region": correctRegion, + }) + })) + defer server.Close() + + config := AuthConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + BaseURL: server.URL, + } + + // Create auth provider - this triggers initial authentication + p, err := NewAuthProvider(config) + if err != nil { + t.Fatalf("NewAuthProvider failed: %v", err) + } + + // Verify we made 2 requests: first failed with hint, second succeeded without + finalCount := atomic.LoadInt32(&requestCount) + if finalCount != 2 { + t.Errorf("Expected 2 requests (failed with hint + retry without), got %d", finalCount) + } + + // Verify auth succeeded + header, err := p.GetAuthorizationHeader(context.Background()) + if err != nil { + t.Fatalf("GetAuthorizationHeader failed: %v", err) + } + expectedHeader := "Bearer " + mockJWT + if header != expectedHeader { + t.Errorf("Expected Bearer JWT token, got %q", header) + } + + // Verify cache was updated with correct region + cachedRegion, ok := loadCachedRegion("test-client") + if !ok { + t.Error("Expected region to be cached after successful auth") + } + if cachedRegion != correctRegion { + t.Errorf("Expected cached region %q, got %q", correctRegion, cachedRegion) + } +} diff --git a/internal/auth/client.go b/internal/auth/client.go index b0b5909..bfb2617 100644 --- a/internal/auth/client.go +++ b/internal/auth/client.go @@ -18,36 +18,46 @@ const ( maxResponseSize = 1 << 20 // 1MB ) +// AuthError represents an authentication failure with HTTP status context. +// This allows callers to distinguish between different failure modes +// (e.g., invalid credentials vs. region-specific rejection vs. server error). +type AuthError struct { + StatusCode int + Message string +} + +func (e *AuthError) Error() string { return e.Message } + // AuthClient handles authentication with an external auth service. type AuthClient struct { - endpoint string + baseURL string httpClient *http.Client debug bool } -// NewAuthClient creates a new authentication client for the given endpoint. -// The endpoint must be a valid HTTPS URL (HTTP allowed only for localhost). +// NewAuthClient creates a new authentication client for the given base URL. +// The base URL must be a valid HTTPS URL (HTTP allowed only for localhost). // If debug is true, authentication failures will log detailed error information. -func NewAuthClient(endpoint string, debug bool) (*AuthClient, error) { - if endpoint == "" { - return nil, fmt.Errorf("auth endpoint is required") +func NewAuthClient(baseURL string, debug bool) (*AuthClient, error) { + if baseURL == "" { + return nil, fmt.Errorf("API base URL is required for authentication") } - parsedURL, err := url.Parse(endpoint) + parsedURL, err := url.Parse(baseURL) if err != nil { - return nil, fmt.Errorf("invalid endpoint URL: %w", err) + return nil, fmt.Errorf("invalid base URL: %w", err) } // Require HTTPS for non-localhost if parsedURL.Scheme != "https" { host := parsedURL.Hostname() if host != "localhost" && host != "127.0.0.1" { - return nil, fmt.Errorf("HTTPS required for non-localhost endpoint") + return nil, fmt.Errorf("HTTPS required for non-localhost URLs") } } return &AuthClient{ - endpoint: strings.TrimSuffix(endpoint, "/"), + baseURL: strings.TrimSuffix(baseURL, "/"), httpClient: &http.Client{ Timeout: 30 * time.Second, }, @@ -57,51 +67,61 @@ func NewAuthClient(endpoint string, debug bool) (*AuthClient, error) { // authRequest is the request body for the authenticate endpoint. type authRequest struct { - ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret"` //nolint:gosec // G117: This is a JSON field name for auth request, not a secret value + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec // G117: This is a JSON field name for auth request, not a secret value + Region *string `json:"region,omitempty"` // Optional region hint from cache } // authResponse is the response from the authenticate endpoint. type authResponse struct { - Token string `json:"token"` - Error string `json:"error,omitempty"` + Token string `json:"token"` + Region string `json:"region,omitempty"` // Discovered region for caching + Error string `json:"error,omitempty"` +} + +// AuthResult contains the authentication response with token and discovered region. +type AuthResult struct { + Token string + Region string } // Authenticate exchanges client credentials for a JWT token. -// Calls POST /api/v1/authenticate with client_id and client_secret. -func (c *AuthClient) Authenticate(ctx context.Context, clientID, clientSecret string) (string, error) { +// Calls POST /api/v1/auth/token with client_id, client_secret, and optional region hint. +// Returns the token and the discovered/confirmed region for caching. +func (c *AuthClient) Authenticate(ctx context.Context, clientID, clientSecret string, regionHint *string) (*AuthResult, error) { reqBody := authRequest{ ClientID: clientID, ClientSecret: clientSecret, + Region: regionHint, } jsonBody, err := json.Marshal(reqBody) if err != nil { - return "", fmt.Errorf("failed to marshal request: %w", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) } - authEndpoint := c.endpoint + "/api/v1/authenticate" + authEndpoint := c.baseURL + "/api/v1/auth/token" req, err := http.NewRequestWithContext(ctx, "POST", authEndpoint, bytes.NewReader(jsonBody)) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := c.httpClient.Do(req) //nolint:gosec // G704: authEndpoint is constructed from validated config, not user input if err != nil { - return "", fmt.Errorf("authentication request failed: %w", err) + return nil, fmt.Errorf("authentication request failed: %w", err) } defer resp.Body.Close() //nolint:errcheck // response body read-only limitedReader := io.LimitReader(resp.Body, maxResponseSize) body, err := io.ReadAll(limitedReader) if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) + return nil, fmt.Errorf("failed to read response: %w", err) } if resp.StatusCode == http.StatusUnauthorized { - return "", fmt.Errorf("invalid credentials") + return nil, &AuthError{StatusCode: resp.StatusCode, Message: "invalid credentials"} } if resp.StatusCode != http.StatusOK { @@ -110,22 +130,25 @@ func (c *AuthClient) Authenticate(ctx context.Context, clientID, clientSecret st fmt.Fprintf(os.Stderr, "DEBUG: Auth failed with status %d, body: %s\n", resp.StatusCode, string(body)) } // Don't include raw response body in error to prevent potential info leakage - return "", fmt.Errorf("authentication failed (status %d)", resp.StatusCode) + return nil, &AuthError{StatusCode: resp.StatusCode, Message: fmt.Sprintf("authentication failed (status %d)", resp.StatusCode)} } var authResp authResponse if err := json.Unmarshal(body, &authResp); err != nil { - return "", fmt.Errorf("failed to parse response: %w", err) + return nil, fmt.Errorf("failed to parse response: %w", err) } if authResp.Error != "" { // Don't include raw error content to prevent potential sensitive info leakage - return "", fmt.Errorf("authentication failed: server returned an error") + return nil, fmt.Errorf("authentication failed: server returned an error") } if authResp.Token == "" { - return "", fmt.Errorf("no token in response") + return nil, fmt.Errorf("no token in response") } - return authResp.Token, nil + return &AuthResult{ + Token: authResp.Token, + Region: authResp.Region, + }, nil } diff --git a/internal/auth/region_cache.go b/internal/auth/region_cache.go new file mode 100644 index 0000000..961eded --- /dev/null +++ b/internal/auth/region_cache.go @@ -0,0 +1,141 @@ +// Package auth provides authentication for the Armis API. +// This file handles region caching to avoid auto-discovery on subsequent authentications. +package auth + +import ( + "encoding/json" + "os" + "path/filepath" + + "github.com/ArmisSecurity/armis-cli/internal/util" +) + +const ( + // regionCacheFileName is the name of the region cache file. + regionCacheFileName = "region-cache.json" + + // maxCacheFileSize limits region cache reads to prevent memory exhaustion + // from corrupted or maliciously large files. The actual cache is ~60 bytes. + maxCacheFileSize = 4096 // 4KB +) + +// regionCacheEntry is the on-disk JSON structure for persisting region. +type regionCacheEntry struct { + ClientID string `json:"client_id"` + Region string `json:"region"` +} + +// RegionCache handles region caching with optional directory override for testing. +type RegionCache struct { + cacheDir string // for testing; empty means use util.GetCacheDir() +} + +// NewRegionCache creates a region cache with default settings. +func NewRegionCache() *RegionCache { + return &RegionCache{} +} + +// Load attempts to load a cached region for the given client ID. +// Returns the region and true if found, empty string and false otherwise. +func (c *RegionCache) Load(clientID string) (string, bool) { + path := c.getFilePath() + if path == "" { + return "", false + } + + info, err := os.Stat(path) + if err != nil { + return "", false + } + if info.Size() > maxCacheFileSize { + return "", false + } + + data, err := os.ReadFile(path) //nolint:gosec // path validated by getFilePath + if err != nil { + return "", false + } + + var cache regionCacheEntry + if err := json.Unmarshal(data, &cache); err != nil { + return "", false + } + + // Only return if client ID matches (prevent cross-credential pollution) + if cache.ClientID != clientID { + return "", false + } + + return cache.Region, cache.Region != "" +} + +// Save persists the region for the given client ID. +// Errors are silently ignored (best-effort caching). +func (c *RegionCache) Save(clientID, region string) { + if clientID == "" || region == "" { + return + } + + path := c.getFilePath() + if path == "" { + return + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return + } + + cache := regionCacheEntry{ + ClientID: clientID, + Region: region, + } + + data, err := json.Marshal(cache) + if err != nil { + return + } + + _ = os.WriteFile(path, data, 0o600) //nolint:gosec // path validated by getFilePath +} + +// Clear removes the cached region. +// Used when authentication fails with a cached region hint. +func (c *RegionCache) Clear() { + path := c.getFilePath() + if path == "" { + return + } + + _ = os.Remove(path) //nolint:errcheck // best-effort cleanup +} + +// getFilePath returns the validated path to the region cache file. +func (c *RegionCache) getFilePath() string { + if c.cacheDir != "" { + // Testing override - validate the provided path + sanitized, err := util.SanitizePath(c.cacheDir) + if err != nil { + return "" + } + return filepath.Join(sanitized, regionCacheFileName) + } + return util.GetCacheFilePath(regionCacheFileName) +} + +// Package-level convenience functions using a default cache instance. +// These maintain backward compatibility with existing code. + +var defaultCache = NewRegionCache() + +func loadCachedRegion(clientID string) (string, bool) { + return defaultCache.Load(clientID) +} + +func saveCachedRegion(clientID, region string) { + defaultCache.Save(clientID, region) +} + +func clearCachedRegion() { + defaultCache.Clear() +} diff --git a/internal/auth/region_cache_test.go b/internal/auth/region_cache_test.go new file mode 100644 index 0000000..2aacd8c --- /dev/null +++ b/internal/auth/region_cache_test.go @@ -0,0 +1,300 @@ +package auth + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +const ( + testRegionUS1 = "us1" + testRegionEU1 = "eu1" + testRegionAP1 = "ap1" +) + +func TestRegionCache_RoundTrip(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + // Initially empty + region, ok := cache.Load("client-123") + if ok { + t.Error("Expected no cached region initially") + } + if region != "" { + t.Errorf("Expected empty region, got %q", region) + } + + // Save and load + cache.Save("client-123", testRegionUS1) + region, ok = cache.Load("client-123") + if !ok { + t.Error("Expected to find cached region after save") + } + if region != testRegionUS1 { + t.Errorf("Expected %q, got %q", testRegionUS1, region) + } +} + +func TestRegionCache_ClientIDMismatch(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + // Save for client-123 + cache.Save("client-123", testRegionUS1) + + // Try to load for different client - should NOT return the cached region + region, ok := cache.Load("client-456") + if ok { + t.Error("Expected cache miss for different client ID") + } + if region != "" { + t.Errorf("Expected empty region for mismatched client, got %q", region) + } + + // Original client should still work + region, ok = cache.Load("client-123") + if !ok { + t.Error("Expected cache hit for original client") + } + if region != testRegionUS1 { + t.Errorf("Expected %q, got %q", testRegionUS1, region) + } +} + +func TestRegionCache_Clear(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + cache.Save("client-123", testRegionUS1) + + // Verify it's there + _, ok := cache.Load("client-123") + if !ok { + t.Fatal("Expected region to be cached before clear") + } + + // Clear + cache.Clear() + + // Verify it's gone + region, ok := cache.Load("client-123") + if ok { + t.Error("Expected cache miss after clear") + } + if region != "" { + t.Errorf("Expected empty region after clear, got %q", region) + } +} + +func TestRegionCache_MissingFile(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + // Load from non-existent file should return empty gracefully + region, ok := cache.Load("client-123") + if ok { + t.Error("Expected cache miss for non-existent file") + } + if region != "" { + t.Errorf("Expected empty region, got %q", region) + } +} + +func TestRegionCache_CorruptJSON(t *testing.T) { + tempDir := t.TempDir() + cache := &RegionCache{cacheDir: tempDir} + + // Write corrupt JSON + cachePath := filepath.Join(tempDir, regionCacheFileName) + if err := os.WriteFile(cachePath, []byte("not valid json{"), 0o600); err != nil { + t.Fatalf("Failed to write corrupt file: %v", err) + } + + // Load should handle corrupt JSON gracefully + region, ok := cache.Load("client-123") + if ok { + t.Error("Expected cache miss for corrupt JSON") + } + if region != "" { + t.Errorf("Expected empty region for corrupt JSON, got %q", region) + } +} + +func TestRegionCache_EmptyRegion(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + // Save with empty region should be a no-op + cache.Save("client-123", "") + + // Should still be empty + region, ok := cache.Load("client-123") + if ok { + t.Error("Expected cache miss after saving empty region") + } + if region != "" { + t.Errorf("Expected empty region, got %q", region) + } +} + +func TestRegionCache_EmptyClientID(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + // Save with empty client ID should be a no-op + cache.Save("", testRegionUS1) + + // Should still be empty + region, ok := cache.Load("") + if ok { + t.Error("Expected cache miss after saving empty client ID") + } + if region != "" { + t.Errorf("Expected empty region, got %q", region) + } +} + +func TestRegionCache_Overwrite(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + // Save initial region + cache.Save("client-123", testRegionUS1) + + // Overwrite with new region + cache.Save("client-123", testRegionEU1) + + // Should return the new region + region, ok := cache.Load("client-123") + if !ok { + t.Error("Expected cache hit after overwrite") + } + if region != testRegionEU1 { + t.Errorf("Expected %q after overwrite, got %q", testRegionEU1, region) + } +} + +func TestRegionCache_FilePermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Unix file permissions not supported on Windows") + } + + tempDir := t.TempDir() + cache := &RegionCache{cacheDir: tempDir} + + cache.Save("client-123", testRegionUS1) + + // Verify file permissions are restrictive (0600) + cachePath := filepath.Join(tempDir, regionCacheFileName) + info, err := os.Stat(cachePath) + if err != nil { + t.Fatalf("Failed to stat cache file: %v", err) + } + + perm := info.Mode().Perm() + if perm != 0o600 { + t.Errorf("Expected file permissions 0600, got %o", perm) + } +} + +func TestRegionCache_ClearNonExistent(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + // Clear should not panic when file doesn't exist + cache.Clear() // Should be a no-op, not panic +} + +func TestRegionCache_TableDriven(t *testing.T) { + tests := []struct { + name string + saveClient string + saveRegion string + loadClient string + wantRegion string + wantFound bool + }{ + { + name: "exact match", + saveClient: "client-a", + saveRegion: testRegionUS1, + loadClient: "client-a", + wantRegion: testRegionUS1, + wantFound: true, + }, + { + name: "client mismatch", + saveClient: "client-a", + saveRegion: testRegionUS1, + loadClient: "client-b", + wantRegion: "", + wantFound: false, + }, + { + name: "case sensitive client ID", + saveClient: "Client-A", + saveRegion: testRegionUS1, + loadClient: "client-a", + wantRegion: "", + wantFound: false, + }, + { + name: "special characters in client ID", + saveClient: "client@123.example.com", + saveRegion: testRegionAP1, + loadClient: "client@123.example.com", + wantRegion: testRegionAP1, + wantFound: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := &RegionCache{cacheDir: t.TempDir()} + + if tt.saveClient != "" && tt.saveRegion != "" { + cache.Save(tt.saveClient, tt.saveRegion) + } + + region, found := cache.Load(tt.loadClient) + if found != tt.wantFound { + t.Errorf("Load(%q) found = %v, want %v", tt.loadClient, found, tt.wantFound) + } + if region != tt.wantRegion { + t.Errorf("Load(%q) region = %q, want %q", tt.loadClient, region, tt.wantRegion) + } + }) + } +} + +// TestPackageLevelFunctions verifies the backward-compatible package functions work. +// Note: These use the global defaultCache, which shares state across tests if run in parallel. +// We test them in isolation by temporarily replacing defaultCache with a test instance. +func TestPackageLevelFunctions(t *testing.T) { + // Use a temporary cache directory so we don't modify the user's real cache + origCache := defaultCache + defaultCache = &RegionCache{cacheDir: t.TempDir()} + t.Cleanup(func() { + defaultCache = origCache + }) + + // Clear any existing cache in the temporary directory + clearCachedRegion() + + // Load from empty should return false + region, ok := loadCachedRegion("test-client-pkg-level") + if ok { + t.Errorf("Expected cache miss from empty cache, got region=%q", region) + } + + // Save and load round-trip + saveCachedRegion("test-client-pkg-level", testRegionUS1) + region, ok = loadCachedRegion("test-client-pkg-level") + if !ok { + t.Error("Expected cache hit after save") + } + if region != testRegionUS1 { + t.Errorf("Expected region %q, got %q", testRegionUS1, region) + } + + // Clear should remove the entry + clearCachedRegion() + _, ok = loadCachedRegion("test-client-pkg-level") + if ok { + t.Error("Expected cache miss after clear") + } +} diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go index 6d8186b..6b5c47c 100644 --- a/internal/cmd/auth.go +++ b/internal/cmd/auth.go @@ -18,16 +18,14 @@ This command is useful for: - Obtaining tokens for use with other tools - Debugging JWT-related issues -Requires --client-id, --client-secret, and --auth-endpoint flags or -their corresponding environment variables (ARMIS_CLIENT_ID, -ARMIS_CLIENT_SECRET, ARMIS_AUTH_ENDPOINT).`, +Requires --client-id and --client-secret flags or their corresponding +environment variables (ARMIS_CLIENT_ID, ARMIS_CLIENT_SECRET).`, Example: ` # Obtain JWT token using flags - armis-cli auth --client-id MY_ID --client-secret MY_SECRET --auth-endpoint https://auth.example.com + armis-cli auth --client-id MY_ID --client-secret MY_SECRET # Obtain token using environment variables export ARMIS_CLIENT_ID=MY_ID export ARMIS_CLIENT_SECRET=MY_SECRET - export ARMIS_AUTH_ENDPOINT=https://auth.example.com armis-cli auth`, RunE: runAuth, } @@ -46,9 +44,6 @@ func runAuth(cmd *cobra.Command, args []string) error { if clientSecret == "" { return fmt.Errorf("--client-secret is required (or set ARMIS_CLIENT_SECRET)") } - if authEndpoint == "" { - return fmt.Errorf("--auth-endpoint is required (or set ARMIS_AUTH_ENDPOINT)") - } provider, err := getAuthProvider() if err != nil { diff --git a/internal/cmd/auth_test.go b/internal/cmd/auth_test.go index 70c3982..d10edba 100644 --- a/internal/cmd/auth_test.go +++ b/internal/cmd/auth_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" "strings" "testing" "time" @@ -55,14 +56,6 @@ func TestRunAuth(t *testing.T) { wantErr: true, errContains: "--client-secret is required", }, - { - name: "missing auth-endpoint", - clientID: "test-client", - clientSecret: "test-secret", - setupServer: false, // No server = empty authEndpoint - wantErr: true, - errContains: "--auth-endpoint is required", - }, { name: "successful authentication", clientID: "test-client", @@ -87,16 +80,20 @@ func TestRunAuth(t *testing.T) { // Save and restore original values (t.Cleanup runs even on panic) origClientID := clientID origClientSecret := clientSecret - origAuthEndpoint := authEndpoint origToken := token origTenantID := tenantID + origAPIURL := os.Getenv("ARMIS_API_URL") t.Cleanup(func() { clientID = origClientID clientSecret = origClientSecret - authEndpoint = origAuthEndpoint token = origToken tenantID = origTenantID + if origAPIURL == "" { + _ = os.Unsetenv("ARMIS_API_URL") + } else { + _ = os.Setenv("ARMIS_API_URL", origAPIURL) + } }) // Clear legacy auth vars to ensure JWT path is taken @@ -111,7 +108,7 @@ func TestRunAuth(t *testing.T) { // Create mock auth server mockJWT := createMockJWT("customer-123", time.Now().Add(time.Hour).Unix()) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/api/v1/authenticate" { + if r.URL.Path != "/api/v1/auth/token" { w.WriteHeader(http.StatusNotFound) return } @@ -131,9 +128,7 @@ func TestRunAuth(t *testing.T) { _ = json.NewEncoder(w).Encode(resp) })) defer server.Close() - authEndpoint = server.URL - } else { - authEndpoint = "" + _ = os.Setenv("ARMIS_API_URL", server.URL) } // Create a minimal cobra command with context @@ -165,16 +160,20 @@ func TestRunAuth_InvalidEndpoint(t *testing.T) { // Save and restore original values (t.Cleanup runs even on panic) origClientID := clientID origClientSecret := clientSecret - origAuthEndpoint := authEndpoint origToken := token origTenantID := tenantID + origAPIURL := os.Getenv("ARMIS_API_URL") t.Cleanup(func() { clientID = origClientID clientSecret = origClientSecret - authEndpoint = origAuthEndpoint token = origToken tenantID = origTenantID + if origAPIURL == "" { + _ = os.Unsetenv("ARMIS_API_URL") + } else { + _ = os.Setenv("ARMIS_API_URL", origAPIURL) + } }) // Clear legacy auth vars @@ -184,7 +183,7 @@ func TestRunAuth_InvalidEndpoint(t *testing.T) { // Set valid credentials but invalid endpoint clientID = "test-client" clientSecret = "test-secret" - authEndpoint = "http://localhost:99999" // Invalid port + _ = os.Setenv("ARMIS_API_URL", "http://localhost:99999") // Invalid port cmd := &cobra.Command{} cmd.SetContext(context.Background()) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 8d1c0ee..6cff988 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -48,7 +48,7 @@ var ( // JWT authentication clientID string clientSecret string - authEndpoint string + region string version = versionDev commit = "none" @@ -114,6 +114,13 @@ var rootCmd = &cobra.Command{ output.SyncColors() + // Warn if the removed ARMIS_AUTH_ENDPOINT env var is set + if os.Getenv("ARMIS_AUTH_ENDPOINT") != "" { + cli.PrintWarning("ARMIS_AUTH_ENDPOINT is no longer supported. " + + "The auth endpoint is now derived from the base URL. " + + "Use ARMIS_API_URL to override the base URL, or --region to specify a region.") + } + // Skip update check if: // - explicitly disabled via flag or env var // - running in CI @@ -163,7 +170,7 @@ func init() { // JWT authentication rootCmd.PersistentFlags().StringVar(&clientID, "client-id", os.Getenv("ARMIS_CLIENT_ID"), "Client ID for JWT authentication (env: ARMIS_CLIENT_ID)") rootCmd.PersistentFlags().StringVar(&clientSecret, "client-secret", os.Getenv("ARMIS_CLIENT_SECRET"), "Client secret for JWT authentication (env: ARMIS_CLIENT_SECRET)") - rootCmd.PersistentFlags().StringVar(&authEndpoint, "auth-endpoint", os.Getenv("ARMIS_AUTH_ENDPOINT"), "Authentication service endpoint URL (env: ARMIS_AUTH_ENDPOINT)") + rootCmd.PersistentFlags().StringVar(®ion, "region", os.Getenv("ARMIS_REGION"), "Override region for authentication (bypasses auto-discovery) (env: ARMIS_REGION)") // General options rootCmd.PersistentFlags().BoolVar(&useDev, "dev", false, "Use development environment instead of production") @@ -266,12 +273,13 @@ func getAPIBaseURL() string { } // getAuthProvider creates an AuthProvider based on the provided credentials. -// Priority: JWT auth (--client-id, --client-secret, --auth-endpoint) > Basic auth (--token) +// Priority: JWT auth (--client-id, --client-secret) > Basic auth (--token) func getAuthProvider() (*auth.AuthProvider, error) { return auth.NewAuthProvider(auth.AuthConfig{ ClientID: clientID, ClientSecret: clientSecret, - AuthEndpoint: authEndpoint, + BaseURL: getAPIBaseURL(), + Region: region, Token: token, TenantID: tenantID, Debug: debug, diff --git a/internal/cmd/root_test.go b/internal/cmd/root_test.go index 50c7f8d..e24c5d1 100644 --- a/internal/cmd/root_test.go +++ b/internal/cmd/root_test.go @@ -739,14 +739,12 @@ func TestGetAuthProvider_NoCredentials(t *testing.T) { // Save original values originalClientID := clientID originalClientSecret := clientSecret - originalAuthEndpoint := authEndpoint originalToken := token originalTenantID := tenantID t.Cleanup(func() { clientID = originalClientID clientSecret = originalClientSecret - authEndpoint = originalAuthEndpoint token = originalToken tenantID = originalTenantID }) @@ -754,7 +752,6 @@ func TestGetAuthProvider_NoCredentials(t *testing.T) { // Clear all auth credentials clientID = "" clientSecret = "" - authEndpoint = "" token = "" tenantID = "" diff --git a/internal/update/update.go b/internal/update/update.go index 7f00cf8..ee90cb8 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -30,9 +30,6 @@ const ( // cacheFileName is the name of the cache file. cacheFileName = "update-check.json" - - // cacheDirName is the subdirectory name for cache files. - cacheDirName = "armis-cli" ) // CheckResult holds the result of a version check. @@ -195,18 +192,15 @@ func (c *Checker) fetchLatestVersion(ctx context.Context) (string, error) { // getCacheFilePath returns the path to the cache file. func (c *Checker) getCacheFilePath() string { if c.cacheDir != "" { - // Validate cacheDir to prevent path traversal (CWE-73) + // Testing override - validate the provided path sanitized, err := util.SanitizePath(c.cacheDir) if err != nil { return "" // invalid cacheDir, disable caching } return filepath.Join(sanitized, cacheFileName) } - cacheDir, err := os.UserCacheDir() - if err != nil { - return "" // no caching possible - } - return filepath.Join(cacheDir, cacheDirName, cacheFileName) + // Use shared utility for default cache path + return util.GetCacheFilePath(cacheFileName) } // readCache attempts to read a cached check result. diff --git a/internal/util/cache.go b/internal/util/cache.go new file mode 100644 index 0000000..5e31a4e --- /dev/null +++ b/internal/util/cache.go @@ -0,0 +1,83 @@ +// Package util provides shared utilities for the CLI. +package util + +import ( + "os" + "path/filepath" + "strings" +) + +const ( + // CacheDirName is the subdirectory name for CLI cache files. + // Used by both update checker and region cache. + CacheDirName = "armis-cli" +) + +// GetCacheDir returns the validated path to the CLI's cache directory. +// Returns empty string if the cache directory cannot be determined or validated. +// The directory is NOT created by this function - callers should create it if needed. +// +// Default location: ~/.cache/armis-cli (or platform equivalent) +func GetCacheDir() string { + userCacheDir, err := os.UserCacheDir() + if err != nil { + return "" + } + + cacheDir := filepath.Join(userCacheDir, CacheDirName) + + // Validate path to prevent traversal attacks (CWE-73) + sanitized, err := SanitizePath(cacheDir) + if err != nil { + return "" + } + + return sanitized +} + +// GetCacheFilePath returns the validated path to a cache file. +// Returns empty string if the path cannot be determined or validated. +// The filename must be a simple filename (no path separators or absolute paths). +func GetCacheFilePath(filename string) string { + cacheDir := GetCacheDir() + if cacheDir == "" { + return "" + } + + // Reject empty, whitespace-only, ".", and ".." filenames + // These would result in returning cacheDir itself, not a file path + filename = strings.TrimSpace(filename) + if filename == "" || filename == "." || filename == ".." { + return "" + } + + // Reject absolute paths - filepath.Join would discard cacheDir (CWE-22) + if filepath.IsAbs(filename) { + return "" + } + + // Reject path separators - filename should be a simple name like "cache.json" + if strings.ContainsAny(filename, `/\`) { + return "" + } + + filePath := filepath.Join(cacheDir, filename) + + // Re-validate the full path (filename could contain traversal attempts) + sanitized, err := SanitizePath(filePath) + if err != nil { + return "" + } + + // Final containment check: ensure result is within cache directory using robust path-based check + // Using filepath.Rel is more robust than strings.HasPrefix (handles case-insensitivity, path separators) + rel, err := filepath.Rel(cacheDir, sanitized) + if err != nil { + return "" + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "" + } + + return sanitized +} diff --git a/internal/util/cache_test.go b/internal/util/cache_test.go new file mode 100644 index 0000000..cfe60b7 --- /dev/null +++ b/internal/util/cache_test.go @@ -0,0 +1,157 @@ +package util + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestGetCacheDir(t *testing.T) { + dir := GetCacheDir() + + // Should not be empty (unless running in a very unusual environment) + if dir == "" { + t.Skip("Unable to determine cache directory (may be running in unusual environment)") + } + + // Should end with our cache dir name + if !strings.HasSuffix(dir, CacheDirName) { + t.Errorf("Expected cache dir to end with %q, got %q", CacheDirName, dir) + } + + // Should be an absolute path + if !filepath.IsAbs(dir) { + t.Errorf("Expected absolute path, got %q", dir) + } +} + +func TestGetCacheFilePath(t *testing.T) { + tests := []struct { + name string + filename string + wantEnd string + }{ + { + name: "simple filename", + filename: "test.json", + wantEnd: filepath.Join(CacheDirName, "test.json"), + }, + { + name: "another filename", + filename: "region-cache.json", + wantEnd: filepath.Join(CacheDirName, "region-cache.json"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := GetCacheFilePath(tt.filename) + if path == "" { + t.Skip("Unable to determine cache file path") + } + + if !strings.HasSuffix(path, tt.wantEnd) { + t.Errorf("GetCacheFilePath(%q) = %q, want suffix %q", tt.filename, path, tt.wantEnd) + } + + if !filepath.IsAbs(path) { + t.Errorf("Expected absolute path, got %q", path) + } + }) + } +} + +func TestGetCacheFilePath_SafeFilenames(t *testing.T) { + // This function is designed to be called with safe, constant filenames. + // The security boundary is enforced by rejecting absolute paths and + // path separators in the filename parameter. + + tests := []struct { + name string + filename string + wantContains string + wantEmpty bool + }{ + {"simple json file", "test.json", "test.json", false}, + {"hyphenated name", "region-cache.json", "region-cache.json", false}, + // Empty and special filenames are rejected (would return directory path) + {"empty filename rejected", "", "", true}, + {"whitespace only rejected", " ", "", true}, + {"dot rejected", ".", "", true}, + {"double dot rejected", "..", "", true}, + // Absolute paths are rejected (CWE-22: filepath.Join would discard cacheDir) + {"absolute path rejected", "/etc/passwd", "", true}, + // Path separators are rejected to ensure filename is a simple name + {"path with forward slash rejected", "foo/bar.json", "", true}, + {"path with backslash rejected", "foo\\bar.json", "", true}, + // Traversal attempts are rejected by SanitizePath + {"traversal rejected", "..\\secret.txt", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := GetCacheFilePath(tt.filename) + + if tt.wantEmpty { + if path != "" { + t.Errorf("GetCacheFilePath(%q) = %q, want empty (rejected)", tt.filename, path) + } + return + } + + if path == "" { + t.Errorf("GetCacheFilePath(%q) returned empty, want non-empty", tt.filename) + return + } + if !strings.Contains(path, tt.wantContains) { + t.Errorf("GetCacheFilePath(%q) = %q, want to contain %q", tt.filename, path, tt.wantContains) + } + // Verify result is within cache directory using a path-safe check + cacheDir := GetCacheDir() + rel, err := filepath.Rel(cacheDir, path) + if err != nil { + t.Fatalf("filepath.Rel(%q, %q) error: %v", cacheDir, path, err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + t.Errorf("GetCacheFilePath(%q) = %q, escapes cache dir %q (rel=%q)", tt.filename, path, cacheDir, rel) + } + }) + } +} + +func TestCacheDirName_Constant(t *testing.T) { + // Verify the constant matches expected value + if CacheDirName != "armis-cli" { + t.Errorf("CacheDirName = %q, want %q", CacheDirName, "armis-cli") + } +} + +func TestGetCacheDir_Idempotent(t *testing.T) { + // Multiple calls should return the same path + dir1 := GetCacheDir() + dir2 := GetCacheDir() + + if dir1 != dir2 { + t.Errorf("GetCacheDir() not idempotent: %q != %q", dir1, dir2) + } +} + +func TestGetCacheDir_UserCacheDir(t *testing.T) { + // Verify our cache dir is under the user's cache directory + userCacheDir, err := os.UserCacheDir() + if err != nil { + t.Skip("os.UserCacheDir() not available") + } + + dir := GetCacheDir() + if dir == "" { + t.Skip("GetCacheDir() returned empty") + } + + expected := filepath.Join(userCacheDir, CacheDirName) + // Compare cleaned paths + if filepath.Clean(dir) != filepath.Clean(expected) { + t.Errorf("GetCacheDir() = %q, want %q", dir, expected) + } +}