diff --git a/cmd/authk/root.go b/cmd/authk/root.go index 7e76219..9bde539 100644 --- a/cmd/authk/root.go +++ b/cmd/authk/root.go @@ -86,16 +86,18 @@ updating a .env file with the valid token.`, // Maintenance Loop for { - // Calculate sleep time (expires_in - 60s buffer) - sleepDuration := time.Duration(token.ExpiresIn-60) * time.Second - if sleepDuration < 10*time.Second { + // Calculate sleep time based on token expiry and a refresh buffer + refreshBuffer := 60 * time.Second // Refresh 60 seconds before expiry + sleepDuration := time.Until(token.Expiry) - refreshBuffer + if sleepDuration < 10*time.Second { // Ensure at least 10 seconds sleep sleepDuration = 10 * time.Second } log.Info().Dur("sleep_duration", sleepDuration).Msg("Waiting for token refresh") time.Sleep(sleepDuration) - newToken, err := client.RefreshToken(token.RefreshToken) + // Attempt to refresh the token + newToken, err := client.RefreshToken(token) if err != nil { log.Error().Err(err).Msg("Failed to refresh token, attempting full re-authentication") @@ -107,7 +109,9 @@ updating a .env file with the valid token.`, time.Sleep(10 * time.Second) // Force short sleep on next iteration to retry quickly - token.ExpiresIn = 0 + // By setting token.Expiry to now, time.Until will be negative, + // and sleepDuration will become 10s. + token.Expiry = time.Now() continue } } diff --git a/go.mod b/go.mod index 2579773..9cfffc7 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,17 @@ go 1.25.1 require ( cuelang.org/go v0.15.1 + github.com/coreos/go-oidc/v3 v3.17.0 github.com/fatih/color v1.18.0 github.com/rs/zerolog v1.34.0 github.com/spf13/cobra v1.10.1 + golang.org/x/oauth2 v0.33.0 ) require ( github.com/cockroachdb/apd/v3 v3.2.1 // indirect github.com/emicklei/proto v1.14.2 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index eff79e8..1c23953 100644 --- a/go.sum +++ b/go.sum @@ -4,12 +4,16 @@ cuelang.org/go v0.15.1 h1:MRnjc/KJE+K42rnJ3a+425f1jqXeOOgq9SK4tYRTtWw= cuelang.org/go v0.15.1/go.mod h1:NYw6n4akZcTjA7QQwJ1/gqWrrhsN4aZwhcAL0jv9rZE= github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg= github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc= +github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= +github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/emicklei/proto v1.14.2 h1:wJPxPy2Xifja9cEMrcA/g08art5+7CGJNFNk35iXC1I= github.com/emicklei/proto v1.14.2/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -61,8 +65,8 @@ golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= -golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= -golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= +golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/env/env_test.go b/internal/env/env_test.go index 6022066..bad2066 100644 --- a/internal/env/env_test.go +++ b/internal/env/env_test.go @@ -149,3 +149,32 @@ func TestFind(t *testing.T) { t.Errorf("Find() = %s, want %s", found, envFile) } } + +func TestFind_NotFound(t *testing.T) { + tmpDir := t.TempDir() + wd, _ := os.Getwd() + defer func() { _ = os.Chdir(wd) }() + if err := os.Chdir(tmpDir); err != nil { + t.Fatal(err) + } + + // Ensure the file doesn't exist in tmpDir or parents (unlikely but possible if running in root) + // And strictly speaking, we should ensure it's not in home dir either for the fallback. + // But a random name is safe enough. + _, err := Find("non_existent_file_random_12345") + if err == nil { + t.Error("Find() expected error for non-existent file, got nil") + } +} + +func TestFind_WithSeparator(t *testing.T) { + // On Linux, this is an absolute path + path := "/tmp/foo/.env" + found, err := Find(path) + if err != nil { + t.Fatal(err) + } + if found != path { + t.Errorf("Find() = %s, want %s", found, path) + } +} diff --git a/internal/oidc/client.go b/internal/oidc/client.go index 611bb00..38f4c99 100644 --- a/internal/oidc/client.go +++ b/internal/oidc/client.go @@ -1,82 +1,68 @@ package oidc import ( - "encoding/json" + "context" "fmt" "net/http" - "net/url" - "strings" "time" "github.com/codozor/authk/internal/config" "github.com/rs/zerolog/log" -) -type TokenResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - IDToken string `json:"id_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` -} + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" +) type Client struct { - cfg *config.Config - httpClient *http.Client - endpoints *providerEndpoints -} - -type providerEndpoints struct { - TokenEndpoint string `json:"token_endpoint"` + cfg *config.Config + provider *oidc.Provider + oauth2Config *oauth2.Config } func NewClient(cfg *config.Config) (*Client, error) { - c := &Client{ - cfg: cfg, - httpClient: &http.Client{Timeout: 30 * time.Second}, - } + ctx := context.Background() - if err := c.discoverEndpoints(); err != nil { - return nil, err - } + // Use custom HTTP client with timeout + httpClient := &http.Client{Timeout: 30 * time.Second} + ctx = oidc.ClientContext(ctx, httpClient) - return c, nil -} - -func (c *Client) discoverEndpoints() error { - wellKnownURL := strings.TrimRight(c.cfg.OIDC.IssuerURL, "/") + "/.well-known/openid-configuration" - resp, err := c.httpClient.Get(wellKnownURL) + provider, err := oidc.NewProvider(ctx, cfg.OIDC.IssuerURL) if err != nil { - return fmt.Errorf("failed to fetch discovery document: %w", err) + return nil, fmt.Errorf("failed to discover OIDC provider: %w", err) } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("discovery request failed with status: %d", resp.StatusCode) + // Determine AuthStyle based on AuthMethod + var authStyle oauth2.AuthStyle + switch cfg.OIDC.AuthMethod { + case "client_secret_post", "post": + authStyle = oauth2.AuthStyleInParams + case "client_secret_basic", "basic", "": // Default to basic if not specified + authStyle = oauth2.AuthStyleInHeader + default: + return nil, fmt.Errorf("unsupported auth method: %s", cfg.OIDC.AuthMethod) } - var endpoints providerEndpoints - if err := json.NewDecoder(resp.Body).Decode(&endpoints); err != nil { - return fmt.Errorf("failed to decode discovery document: %w", err) + oauth2Config := &oauth2.Config{ + ClientID: cfg.OIDC.ClientID, + ClientSecret: cfg.OIDC.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: provider.Endpoint().AuthURL, + TokenURL: provider.Endpoint().TokenURL, + AuthStyle: authStyle, // Set AuthStyle here + }, + Scopes: cfg.OIDC.Scopes, } - c.endpoints = &endpoints - return nil + return &Client{ + cfg: cfg, + provider: provider, + oauth2Config: oauth2Config, + }, nil } -func (c *Client) RefreshToken(refreshToken string) (*TokenResponse, error) { - log.Info().Msg("Refreshing token...") - - data := url.Values{} - data.Set("grant_type", "refresh_token") - data.Set("refresh_token", refreshToken) - - return c.makeTokenRequest(data) -} - -func (c *Client) GetToken(username, password string) (*TokenResponse, error) { - data := url.Values{} - data.Set("scope", strings.Join(c.cfg.OIDC.Scopes, " ")) +func (c *Client) GetToken(username, password string) (*oauth2.Token, error) { + ctx := context.Background() // Use config credentials if provided, otherwise fallback to args or client credentials user := username @@ -88,69 +74,57 @@ func (c *Client) GetToken(username, password string) (*TokenResponse, error) { pass = c.cfg.User.Password } + var token *oauth2.Token + var err error + if user != "" && pass != "" { log.Info().Str("grant_type", "password").Msg("Using Resource Owner Password Credentials flow") - data.Set("grant_type", "password") - data.Set("username", user) - data.Set("password", pass) + token, err = c.oauth2Config.PasswordCredentialsToken(ctx, user, pass) } else { log.Info().Str("grant_type", "client_credentials").Msg("Using Client Credentials flow") - data.Set("grant_type", "client_credentials") - } - - return c.makeTokenRequest(data) -} - -func (c *Client) makeTokenRequest(data url.Values) (*TokenResponse, error) { - // Handle Auth Method - // Handle Auth Method - // Default to basic - // RFC 6749 says client_id in body is NOT RECOMMENDED for Basic Auth, - // but we'll leave it out to be strict. - // If the user wants it in body, they should use "post" or we'd need a "basic_with_body" option. - // For now, let's stick to strict Basic Auth. - if c.cfg.OIDC.AuthMethod == "post" { - data.Set("client_id", c.cfg.OIDC.ClientID) - data.Set("client_secret", c.cfg.OIDC.ClientSecret) + // For client credentials, we need to create a clientcredentials.Config + ccConfig := clientcredentials.Config{ + ClientID: c.oauth2Config.ClientID, + ClientSecret: c.oauth2Config.ClientSecret, + TokenURL: c.oauth2Config.Endpoint.TokenURL, + Scopes: c.oauth2Config.Scopes, + AuthStyle: c.oauth2Config.Endpoint.AuthStyle, + } + // The clientcredentials.Config should use the http client set in the context + token, err = ccConfig.Token(ctx) } - log.Debug(). - Str("endpoint", c.endpoints.TokenEndpoint). - Str("auth_method", c.cfg.OIDC.AuthMethod). - Str("grant_type", data.Get("grant_type")). - Msg("Making token request") - - req, err := http.NewRequest("POST", c.endpoints.TokenEndpoint, strings.NewReader(data.Encode())) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("failed to get token: %w", err) } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - if c.cfg.OIDC.AuthMethod != "post" { - req.SetBasicAuth(c.cfg.OIDC.ClientID, c.cfg.OIDC.ClientSecret) + // Validate ID Token if present + if idTokenRaw, ok := token.Extra("id_token").(string); ok && idTokenRaw != "" { + verifier := c.provider.Verifier(&oidc.Config{ClientID: c.cfg.OIDC.ClientID}) + idToken, err := verifier.Verify(ctx, idTokenRaw) + if err != nil { + return nil, fmt.Errorf("failed to verify ID token: %w", err) + } + log.Debug(). + Str("issuer", idToken.Issuer). + Str("subject", idToken.Subject). + Msg("ID Token validated successfully") + } else { + log.Debug().Msg("No ID Token found or provided in response") } - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("token request failed: %w", err) - } - defer resp.Body.Close() + return token, nil +} - if resp.StatusCode != http.StatusOK { - // Try to read body for error details - var errResp map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { - log.Debug().Err(err).Msg("Failed to decode error response body") - } - log.Debug().Interface("error_response", errResp).Msg("Token request failed") - return nil, fmt.Errorf("token request returned status %d: %v", resp.StatusCode, errResp) - } +// RefreshToken refreshes an expired token using the oauth2 library. +// It takes the existing *oauth2.Token which must contain a valid RefreshToken. +func (c *Client) RefreshToken(oldToken *oauth2.Token) (*oauth2.Token, error) { + ctx := context.Background() - var tokenResp TokenResponse - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return nil, fmt.Errorf("failed to decode token response: %w", err) + tokenSource := c.oauth2Config.TokenSource(ctx, oldToken) + newToken, err := tokenSource.Token() + if err != nil { + return nil, fmt.Errorf("failed to refresh token: %w", err) } - - return &tokenResp, nil + return newToken, nil } diff --git a/internal/oidc/client_test.go b/internal/oidc/client_test.go index d25fbdf..8159350 100644 --- a/internal/oidc/client_test.go +++ b/internal/oidc/client_test.go @@ -5,17 +5,35 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/codozor/authk/internal/config" + "golang.org/x/oauth2" ) +type mockTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int `json:"expires_in"` // seconds + IDToken string `json:"id_token,omitempty"` +} + func TestClient_GetToken(t *testing.T) { // Mock OIDC Provider - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var testServer *httptest.Server + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/.well-known/openid-configuration": - if err := json.NewEncoder(w).Encode(map[string]string{ - "token_endpoint": "http://" + r.Host + "/token", + // go-oidc requires an issuer that matches the discovery URL + // and a jwks_uri for ID token validation (even if we don't validate it in this test) + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": testServer.URL, + "token_endpoint": testServer.URL + "/token", + "jwks_uri": testServer.URL + "/certs", // Dummy JWKS URI + "response_types_supported": []string{"code"}, // Minimal required by go-oidc + "subject_types_supported": []string{"public"}, + "id_token_signing_alg_values_supported": []string{"RS256"}, }); err != nil { t.Error(err) } @@ -24,11 +42,13 @@ func TestClient_GetToken(t *testing.T) { t.Error(err) } if r.Form.Get("grant_type") == "client_credentials" { - if err := json.NewEncoder(w).Encode(TokenResponse{ + w.Header().Set("Content-Type", "application/json") + resp := mockTokenResponse{ AccessToken: "mock_access_token", ExpiresIn: 3600, TokenType: "Bearer", - }); err != nil { + } + if err := json.NewEncoder(w).Encode(resp); err != nil { t.Error(err) } } else { @@ -37,15 +57,16 @@ func TestClient_GetToken(t *testing.T) { default: w.WriteHeader(http.StatusNotFound) } - })) - defer ts.Close() + }) + testServer = httptest.NewServer(handler) + defer testServer.Close() cfg := &config.Config{ OIDC: config.OIDCConfig{ - IssuerURL: ts.URL, + IssuerURL: testServer.URL, ClientID: "client", ClientSecret: "secret", - AuthMethod: "basic", + AuthMethod: "client_secret_basic", }, } @@ -62,14 +83,92 @@ func TestClient_GetToken(t *testing.T) { if token.AccessToken != "mock_access_token" { t.Errorf("expected access token 'mock_access_token', got %s", token.AccessToken) } + if token.TokenType != "Bearer" { + t.Errorf("expected token type 'Bearer', got %s", token.TokenType) + } + if token.Expiry.IsZero() { + t.Error("expected token expiry to be set") + } +} + +func TestClient_GetToken_Password(t *testing.T) { + // Mock OIDC Provider + var testServer *httptest.Server + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": testServer.URL, + "token_endpoint": testServer.URL + "/token", + "jwks_uri": testServer.URL + "/certs", + "response_types_supported": []string{"code"}, + "subject_types_supported": []string{"public"}, + "id_token_signing_alg_values_supported": []string{"RS256"}, + }); err != nil { + t.Error(err) + } + case "/token": + if err := r.ParseForm(); err != nil { + t.Error(err) + } + if r.Form.Get("grant_type") == "password" && r.Form.Get("username") == "testuser" && r.Form.Get("password") == "testpass" { + w.Header().Set("Content-Type", "application/json") + resp := mockTokenResponse{ + AccessToken: "mock_password_access_token", + ExpiresIn: 3600, + TokenType: "Bearer", + } + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Error(err) + } + } else { + w.WriteHeader(http.StatusBadRequest) + } + default: + w.WriteHeader(http.StatusNotFound) + } + }) + testServer = httptest.NewServer(handler) + defer testServer.Close() + + cfg := &config.Config{ + OIDC: config.OIDCConfig{ + IssuerURL: testServer.URL, + ClientID: "client", + ClientSecret: "secret", + AuthMethod: "client_secret_basic", + }, + } + + client, err := NewClient(cfg) + if err != nil { + t.Fatalf("NewClient() error = %v", err) + } + + token, err := client.GetToken("testuser", "testpass") + if err != nil { + t.Fatalf("GetToken() error = %v", err) + } + + if token.AccessToken != "mock_password_access_token" { + t.Errorf("expected access token 'mock_password_access_token', got %s", token.AccessToken) + } } func TestClient_RefreshToken(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var testServer *httptest.Server + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/.well-known/openid-configuration": - if err := json.NewEncoder(w).Encode(map[string]string{ - "token_endpoint": "http://" + r.Host + "/token", + // go-oidc requires an issuer that matches the discovery URL + // and a jwks_uri for ID token validation (even if we don't validate it in this test) + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": testServer.URL, + "token_endpoint": testServer.URL + "/token", + "jwks_uri": testServer.URL + "/certs", // Dummy JWKS URI + "response_types_supported": []string{"code"}, // Minimal required by go-oidc + "subject_types_supported": []string{"public"}, + "id_token_signing_alg_values_supported": []string{"RS256"}, }); err != nil { t.Error(err) } @@ -78,22 +177,36 @@ func TestClient_RefreshToken(t *testing.T) { t.Error(err) } if r.Form.Get("grant_type") == "refresh_token" && r.Form.Get("refresh_token") == "valid_refresh" { - if err := json.NewEncoder(w).Encode(TokenResponse{ + w.Header().Set("Content-Type", "application/json") + resp := mockTokenResponse{ AccessToken: "new_access_token", ExpiresIn: 3600, - }); err != nil { + TokenType: "Bearer", + } + if err := json.NewEncoder(w).Encode(resp); err != nil { t.Error(err) } } else { w.WriteHeader(http.StatusBadRequest) } + case "/certs": + // Provide a minimal JWKS endpoint for go-oidc + if _, err := w.Write([]byte(`{"keys":[]}`)); err != nil { + t.Errorf("w.Write failed: %v", err) + } + default: + w.WriteHeader(http.StatusNotFound) } - })) - defer ts.Close() + }) + testServer = httptest.NewServer(handler) + defer testServer.Close() cfg := &config.Config{ OIDC: config.OIDCConfig{ - IssuerURL: ts.URL, + IssuerURL: testServer.URL, + ClientID: "client", + ClientSecret: "secret", + AuthMethod: "client_secret_basic", }, } @@ -102,7 +215,13 @@ func TestClient_RefreshToken(t *testing.T) { t.Fatalf("NewClient() error = %v", err) } - token, err := client.RefreshToken("valid_refresh") + // Create a dummy old token with the refresh token + oldToken := &oauth2.Token{ + RefreshToken: "valid_refresh", + Expiry: time.Now().Add(-1 * time.Hour), // Expired to force refresh + } + + token, err := client.RefreshToken(oldToken) if err != nil { t.Fatalf("RefreshToken() error = %v", err) } @@ -110,4 +229,10 @@ func TestClient_RefreshToken(t *testing.T) { if token.AccessToken != "new_access_token" { t.Errorf("expected access token 'new_access_token', got %s", token.AccessToken) } + if token.TokenType != "Bearer" { + t.Errorf("expected token type 'Bearer', got %s", token.TokenType) + } + if token.Expiry.IsZero() { + t.Error("expected token expiry to be set") + } }