Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions cmd/authk/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,18 @@ updating a .env file with the valid token.`,

// Maintenance Loop
for {
// Calculate sleep time (expires_in - 60s buffer)
sleepDuration := time.Duration(token.ExpiresIn-60) * time.Second
if sleepDuration < 10*time.Second {
// Calculate sleep time based on token expiry and a refresh buffer
refreshBuffer := 60 * time.Second // Refresh 60 seconds before expiry
sleepDuration := time.Until(token.Expiry) - refreshBuffer
if sleepDuration < 10*time.Second { // Ensure at least 10 seconds sleep
sleepDuration = 10 * time.Second
}

log.Info().Dur("sleep_duration", sleepDuration).Msg("Waiting for token refresh")
time.Sleep(sleepDuration)

newToken, err := client.RefreshToken(token.RefreshToken)
// Attempt to refresh the token
newToken, err := client.RefreshToken(token)
if err != nil {
log.Error().Err(err).Msg("Failed to refresh token, attempting full re-authentication")

Expand All @@ -107,7 +109,9 @@ updating a .env file with the valid token.`,
time.Sleep(10 * time.Second)

// Force short sleep on next iteration to retry quickly
token.ExpiresIn = 0
// By setting token.Expiry to now, time.Until will be negative,
// and sleepDuration will become 10s.
token.Expiry = time.Now()
continue
}
}
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@ go 1.25.1

require (
cuelang.org/go v0.15.1
github.com/coreos/go-oidc/v3 v3.17.0
github.com/fatih/color v1.18.0
github.com/rs/zerolog v1.34.0
github.com/spf13/cobra v1.10.1
golang.org/x/oauth2 v0.33.0
)

require (
github.com/cockroachdb/apd/v3 v3.2.1 // indirect
github.com/emicklei/proto v1.14.2 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
Expand Down
8 changes: 6 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ cuelang.org/go v0.15.1 h1:MRnjc/KJE+K42rnJ3a+425f1jqXeOOgq9SK4tYRTtWw=
cuelang.org/go v0.15.1/go.mod h1:NYw6n4akZcTjA7QQwJ1/gqWrrhsN4aZwhcAL0jv9rZE=
github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg=
github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc=
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/emicklei/proto v1.14.2 h1:wJPxPy2Xifja9cEMrcA/g08art5+7CGJNFNk35iXC1I=
github.com/emicklei/proto v1.14.2/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
Expand Down Expand Up @@ -61,8 +65,8 @@ golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo=
golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down
29 changes: 29 additions & 0 deletions internal/env/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,32 @@ func TestFind(t *testing.T) {
t.Errorf("Find() = %s, want %s", found, envFile)
}
}

func TestFind_NotFound(t *testing.T) {
tmpDir := t.TempDir()
wd, _ := os.Getwd()
defer func() { _ = os.Chdir(wd) }()
if err := os.Chdir(tmpDir); err != nil {
t.Fatal(err)
}

// Ensure the file doesn't exist in tmpDir or parents (unlikely but possible if running in root)
// And strictly speaking, we should ensure it's not in home dir either for the fallback.
// But a random name is safe enough.
_, err := Find("non_existent_file_random_12345")
if err == nil {
t.Error("Find() expected error for non-existent file, got nil")
}
}

func TestFind_WithSeparator(t *testing.T) {
// On Linux, this is an absolute path
path := "/tmp/foo/.env"
found, err := Find(path)
if err != nil {
t.Fatal(err)
}
if found != path {
t.Errorf("Find() = %s, want %s", found, path)
}
}
182 changes: 78 additions & 104 deletions internal/oidc/client.go
Original file line number Diff line number Diff line change
@@ -1,82 +1,68 @@
package oidc

import (
"encoding/json"
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/codozor/authk/internal/config"
"github.com/rs/zerolog/log"
)

type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

type Client struct {
cfg *config.Config
httpClient *http.Client
endpoints *providerEndpoints
}

type providerEndpoints struct {
TokenEndpoint string `json:"token_endpoint"`
cfg *config.Config
provider *oidc.Provider
oauth2Config *oauth2.Config
}

func NewClient(cfg *config.Config) (*Client, error) {
c := &Client{
cfg: cfg,
httpClient: &http.Client{Timeout: 30 * time.Second},
}
ctx := context.Background()

if err := c.discoverEndpoints(); err != nil {
return nil, err
}
// Use custom HTTP client with timeout
httpClient := &http.Client{Timeout: 30 * time.Second}
ctx = oidc.ClientContext(ctx, httpClient)

return c, nil
}

func (c *Client) discoverEndpoints() error {
wellKnownURL := strings.TrimRight(c.cfg.OIDC.IssuerURL, "/") + "/.well-known/openid-configuration"
resp, err := c.httpClient.Get(wellKnownURL)
provider, err := oidc.NewProvider(ctx, cfg.OIDC.IssuerURL)
if err != nil {
return fmt.Errorf("failed to fetch discovery document: %w", err)
return nil, fmt.Errorf("failed to discover OIDC provider: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("discovery request failed with status: %d", resp.StatusCode)
// Determine AuthStyle based on AuthMethod
var authStyle oauth2.AuthStyle
switch cfg.OIDC.AuthMethod {
case "client_secret_post", "post":
authStyle = oauth2.AuthStyleInParams
case "client_secret_basic", "basic", "": // Default to basic if not specified
authStyle = oauth2.AuthStyleInHeader
default:
return nil, fmt.Errorf("unsupported auth method: %s", cfg.OIDC.AuthMethod)
}

var endpoints providerEndpoints
if err := json.NewDecoder(resp.Body).Decode(&endpoints); err != nil {
return fmt.Errorf("failed to decode discovery document: %w", err)
oauth2Config := &oauth2.Config{
ClientID: cfg.OIDC.ClientID,
ClientSecret: cfg.OIDC.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: provider.Endpoint().AuthURL,
TokenURL: provider.Endpoint().TokenURL,
AuthStyle: authStyle, // Set AuthStyle here
},
Scopes: cfg.OIDC.Scopes,
}

c.endpoints = &endpoints
return nil
return &Client{
cfg: cfg,
provider: provider,
oauth2Config: oauth2Config,
}, nil
}

func (c *Client) RefreshToken(refreshToken string) (*TokenResponse, error) {
log.Info().Msg("Refreshing token...")

data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)

return c.makeTokenRequest(data)
}

func (c *Client) GetToken(username, password string) (*TokenResponse, error) {
data := url.Values{}
data.Set("scope", strings.Join(c.cfg.OIDC.Scopes, " "))
func (c *Client) GetToken(username, password string) (*oauth2.Token, error) {
ctx := context.Background()

// Use config credentials if provided, otherwise fallback to args or client credentials
user := username
Expand All @@ -88,69 +74,57 @@ func (c *Client) GetToken(username, password string) (*TokenResponse, error) {
pass = c.cfg.User.Password
}

var token *oauth2.Token
var err error

if user != "" && pass != "" {
log.Info().Str("grant_type", "password").Msg("Using Resource Owner Password Credentials flow")
data.Set("grant_type", "password")
data.Set("username", user)
data.Set("password", pass)
token, err = c.oauth2Config.PasswordCredentialsToken(ctx, user, pass)
} else {
log.Info().Str("grant_type", "client_credentials").Msg("Using Client Credentials flow")
data.Set("grant_type", "client_credentials")
}

return c.makeTokenRequest(data)
}

func (c *Client) makeTokenRequest(data url.Values) (*TokenResponse, error) {
// Handle Auth Method
// Handle Auth Method
// Default to basic
// RFC 6749 says client_id in body is NOT RECOMMENDED for Basic Auth,
// but we'll leave it out to be strict.
// If the user wants it in body, they should use "post" or we'd need a "basic_with_body" option.
// For now, let's stick to strict Basic Auth.
if c.cfg.OIDC.AuthMethod == "post" {
data.Set("client_id", c.cfg.OIDC.ClientID)
data.Set("client_secret", c.cfg.OIDC.ClientSecret)
// For client credentials, we need to create a clientcredentials.Config
ccConfig := clientcredentials.Config{
ClientID: c.oauth2Config.ClientID,
ClientSecret: c.oauth2Config.ClientSecret,
TokenURL: c.oauth2Config.Endpoint.TokenURL,
Scopes: c.oauth2Config.Scopes,
AuthStyle: c.oauth2Config.Endpoint.AuthStyle,
}
// The clientcredentials.Config should use the http client set in the context
token, err = ccConfig.Token(ctx)
}

log.Debug().
Str("endpoint", c.endpoints.TokenEndpoint).
Str("auth_method", c.cfg.OIDC.AuthMethod).
Str("grant_type", data.Get("grant_type")).
Msg("Making token request")

req, err := http.NewRequest("POST", c.endpoints.TokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, fmt.Errorf("failed to get token: %w", err)
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

if c.cfg.OIDC.AuthMethod != "post" {
req.SetBasicAuth(c.cfg.OIDC.ClientID, c.cfg.OIDC.ClientSecret)
// Validate ID Token if present
if idTokenRaw, ok := token.Extra("id_token").(string); ok && idTokenRaw != "" {
verifier := c.provider.Verifier(&oidc.Config{ClientID: c.cfg.OIDC.ClientID})
idToken, err := verifier.Verify(ctx, idTokenRaw)
if err != nil {
return nil, fmt.Errorf("failed to verify ID token: %w", err)
}
log.Debug().
Str("issuer", idToken.Issuer).
Str("subject", idToken.Subject).
Msg("ID Token validated successfully")
} else {
log.Debug().Msg("No ID Token found or provided in response")
}

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token request failed: %w", err)
}
defer resp.Body.Close()
return token, nil
}

if resp.StatusCode != http.StatusOK {
// Try to read body for error details
var errResp map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil {
log.Debug().Err(err).Msg("Failed to decode error response body")
}
log.Debug().Interface("error_response", errResp).Msg("Token request failed")
return nil, fmt.Errorf("token request returned status %d: %v", resp.StatusCode, errResp)
}
// RefreshToken refreshes an expired token using the oauth2 library.
// It takes the existing *oauth2.Token which must contain a valid RefreshToken.
func (c *Client) RefreshToken(oldToken *oauth2.Token) (*oauth2.Token, error) {
ctx := context.Background()

var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return nil, fmt.Errorf("failed to decode token response: %w", err)
tokenSource := c.oauth2Config.TokenSource(ctx, oldToken)
newToken, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}

return &tokenResp, nil
return newToken, nil
}
Loading