Skip to content

Commit 0734aea

Browse files
✨ feat(oidc): Refactor OIDC client with golang.org/x/oauth2 and go-oidc (#6)
* ✨ feat(oidc): Refactor OIDC client with golang.org/x/oauth2 and go-oidc This commit refactors the `internal/oidc` package to leverage the `golang.org/x/oauth2` and `github.com/coreos/go-oidc` libraries. The manual implementation of OAuth2/OIDC flows, including discovery, token requests, and JSON parsing, has been replaced by robust, industry-standard libraries. Key changes include: - Replaced manual OIDC discovery with `oidc.NewProvider`. - Updated token retrieval (password and client credentials grants) to use `oauth2.Config` and `clientcredentials.Config`. - Refactored token refreshing to use `oauth2.TokenSource` mechanisms. - Eliminated custom `TokenResponse` struct in favor of `oauth2.Token`. - Enhanced test suite (`internal/oidc/client_test.go`) to reflect new implementation and ensure compatibility with `go-oidc` expectations. - Updated `cmd/authk/root.go` to use `token.Expiry` for refresh timing, removing custom `ExpiresIn` logic. This refactoring significantly improves security (due to strict validation of OIDC specs by go-oidc), maintainability, and reduces the amount of custom code. * ✅ feat: Improve test coverage and fix linting issue This commit addresses the recent drop in test coverage reported by Coveralls and fixes a linting issue. Changes include: - **internal/oidc/client_test.go:** - Added error checking for `w.Write` call in mock server to resolve an `errcheck` linting error. - Introduced `TestClient_GetToken_Password` to specifically test the Resource Owner Password Credentials flow, increasing coverage for `GetToken` function. - **internal/env/env_test.go:** - Added `TestFind_NotFound` to verify error handling when a file is not found. - Added `TestFind_WithSeparator` to test `Find` function behavior with paths containing separators. These changes collectively improve the overall test coverage and code quality. * 🐛 feat(oidc): Support "basic" and "post" auth methods The OIDC client now correctly handles "basic" and "post" as authentication method configurations, aligning with the `schema.cue` definition. Previously, only "client_secret_basic" and "client_secret_post" were recognized, leading to an "unsupported auth method" error when "basic" or "post" were used in the configuration. This change ensures that the OIDC client initialization works as expected with the simplified auth method names defined in the schema, while maintaining backward compatibility with the more verbose OIDC standard names.
1 parent 80a4a36 commit 0734aea

6 files changed

Lines changed: 268 additions & 129 deletions

File tree

cmd/authk/root.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,18 @@ updating a .env file with the valid token.`,
8686

8787
// Maintenance Loop
8888
for {
89-
// Calculate sleep time (expires_in - 60s buffer)
90-
sleepDuration := time.Duration(token.ExpiresIn-60) * time.Second
91-
if sleepDuration < 10*time.Second {
89+
// Calculate sleep time based on token expiry and a refresh buffer
90+
refreshBuffer := 60 * time.Second // Refresh 60 seconds before expiry
91+
sleepDuration := time.Until(token.Expiry) - refreshBuffer
92+
if sleepDuration < 10*time.Second { // Ensure at least 10 seconds sleep
9293
sleepDuration = 10 * time.Second
9394
}
9495

9596
log.Info().Dur("sleep_duration", sleepDuration).Msg("Waiting for token refresh")
9697
time.Sleep(sleepDuration)
9798

98-
newToken, err := client.RefreshToken(token.RefreshToken)
99+
// Attempt to refresh the token
100+
newToken, err := client.RefreshToken(token)
99101
if err != nil {
100102
log.Error().Err(err).Msg("Failed to refresh token, attempting full re-authentication")
101103

@@ -107,7 +109,9 @@ updating a .env file with the valid token.`,
107109
time.Sleep(10 * time.Second)
108110

109111
// Force short sleep on next iteration to retry quickly
110-
token.ExpiresIn = 0
112+
// By setting token.Expiry to now, time.Until will be negative,
113+
// and sleepDuration will become 10s.
114+
token.Expiry = time.Now()
111115
continue
112116
}
113117
}

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@ go 1.25.1
44

55
require (
66
cuelang.org/go v0.15.1
7+
github.com/coreos/go-oidc/v3 v3.17.0
78
github.com/fatih/color v1.18.0
89
github.com/rs/zerolog v1.34.0
910
github.com/spf13/cobra v1.10.1
11+
golang.org/x/oauth2 v0.33.0
1012
)
1113

1214
require (
1315
github.com/cockroachdb/apd/v3 v3.2.1 // indirect
1416
github.com/emicklei/proto v1.14.2 // indirect
17+
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
1518
github.com/google/uuid v1.6.0 // indirect
1619
github.com/inconshreveable/mousetrap v1.1.0 // indirect
1720
github.com/mattn/go-colorable v0.1.13 // indirect

go.sum

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@ cuelang.org/go v0.15.1 h1:MRnjc/KJE+K42rnJ3a+425f1jqXeOOgq9SK4tYRTtWw=
44
cuelang.org/go v0.15.1/go.mod h1:NYw6n4akZcTjA7QQwJ1/gqWrrhsN4aZwhcAL0jv9rZE=
55
github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg=
66
github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc=
7+
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
8+
github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
79
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
810
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
911
github.com/emicklei/proto v1.14.2 h1:wJPxPy2Xifja9cEMrcA/g08art5+7CGJNFNk35iXC1I=
1012
github.com/emicklei/proto v1.14.2/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
1113
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
1214
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
15+
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
16+
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
1317
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
1418
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
1519
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
@@ -61,8 +65,8 @@ golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
6165
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
6266
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
6367
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
64-
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
65-
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
68+
golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo=
69+
golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
6670
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
6771
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
6872
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

internal/env/env_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,32 @@ func TestFind(t *testing.T) {
149149
t.Errorf("Find() = %s, want %s", found, envFile)
150150
}
151151
}
152+
153+
func TestFind_NotFound(t *testing.T) {
154+
tmpDir := t.TempDir()
155+
wd, _ := os.Getwd()
156+
defer func() { _ = os.Chdir(wd) }()
157+
if err := os.Chdir(tmpDir); err != nil {
158+
t.Fatal(err)
159+
}
160+
161+
// Ensure the file doesn't exist in tmpDir or parents (unlikely but possible if running in root)
162+
// And strictly speaking, we should ensure it's not in home dir either for the fallback.
163+
// But a random name is safe enough.
164+
_, err := Find("non_existent_file_random_12345")
165+
if err == nil {
166+
t.Error("Find() expected error for non-existent file, got nil")
167+
}
168+
}
169+
170+
func TestFind_WithSeparator(t *testing.T) {
171+
// On Linux, this is an absolute path
172+
path := "/tmp/foo/.env"
173+
found, err := Find(path)
174+
if err != nil {
175+
t.Fatal(err)
176+
}
177+
if found != path {
178+
t.Errorf("Find() = %s, want %s", found, path)
179+
}
180+
}

internal/oidc/client.go

Lines changed: 78 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,68 @@
11
package oidc
22

33
import (
4-
"encoding/json"
4+
"context"
55
"fmt"
66
"net/http"
7-
"net/url"
8-
"strings"
97
"time"
108

119
"github.com/codozor/authk/internal/config"
1210
"github.com/rs/zerolog/log"
13-
)
1411

15-
type TokenResponse struct {
16-
AccessToken string `json:"access_token"`
17-
RefreshToken string `json:"refresh_token"`
18-
IDToken string `json:"id_token"`
19-
TokenType string `json:"token_type"`
20-
ExpiresIn int `json:"expires_in"`
21-
}
12+
"github.com/coreos/go-oidc/v3/oidc"
13+
"golang.org/x/oauth2"
14+
"golang.org/x/oauth2/clientcredentials"
15+
)
2216

2317
type Client struct {
24-
cfg *config.Config
25-
httpClient *http.Client
26-
endpoints *providerEndpoints
27-
}
28-
29-
type providerEndpoints struct {
30-
TokenEndpoint string `json:"token_endpoint"`
18+
cfg *config.Config
19+
provider *oidc.Provider
20+
oauth2Config *oauth2.Config
3121
}
3222

3323
func NewClient(cfg *config.Config) (*Client, error) {
34-
c := &Client{
35-
cfg: cfg,
36-
httpClient: &http.Client{Timeout: 30 * time.Second},
37-
}
24+
ctx := context.Background()
3825

39-
if err := c.discoverEndpoints(); err != nil {
40-
return nil, err
41-
}
26+
// Use custom HTTP client with timeout
27+
httpClient := &http.Client{Timeout: 30 * time.Second}
28+
ctx = oidc.ClientContext(ctx, httpClient)
4229

43-
return c, nil
44-
}
45-
46-
func (c *Client) discoverEndpoints() error {
47-
wellKnownURL := strings.TrimRight(c.cfg.OIDC.IssuerURL, "/") + "/.well-known/openid-configuration"
48-
resp, err := c.httpClient.Get(wellKnownURL)
30+
provider, err := oidc.NewProvider(ctx, cfg.OIDC.IssuerURL)
4931
if err != nil {
50-
return fmt.Errorf("failed to fetch discovery document: %w", err)
32+
return nil, fmt.Errorf("failed to discover OIDC provider: %w", err)
5133
}
52-
defer resp.Body.Close()
5334

54-
if resp.StatusCode != http.StatusOK {
55-
return fmt.Errorf("discovery request failed with status: %d", resp.StatusCode)
35+
// Determine AuthStyle based on AuthMethod
36+
var authStyle oauth2.AuthStyle
37+
switch cfg.OIDC.AuthMethod {
38+
case "client_secret_post", "post":
39+
authStyle = oauth2.AuthStyleInParams
40+
case "client_secret_basic", "basic", "": // Default to basic if not specified
41+
authStyle = oauth2.AuthStyleInHeader
42+
default:
43+
return nil, fmt.Errorf("unsupported auth method: %s", cfg.OIDC.AuthMethod)
5644
}
5745

58-
var endpoints providerEndpoints
59-
if err := json.NewDecoder(resp.Body).Decode(&endpoints); err != nil {
60-
return fmt.Errorf("failed to decode discovery document: %w", err)
46+
oauth2Config := &oauth2.Config{
47+
ClientID: cfg.OIDC.ClientID,
48+
ClientSecret: cfg.OIDC.ClientSecret,
49+
Endpoint: oauth2.Endpoint{
50+
AuthURL: provider.Endpoint().AuthURL,
51+
TokenURL: provider.Endpoint().TokenURL,
52+
AuthStyle: authStyle, // Set AuthStyle here
53+
},
54+
Scopes: cfg.OIDC.Scopes,
6155
}
6256

63-
c.endpoints = &endpoints
64-
return nil
57+
return &Client{
58+
cfg: cfg,
59+
provider: provider,
60+
oauth2Config: oauth2Config,
61+
}, nil
6562
}
6663

67-
func (c *Client) RefreshToken(refreshToken string) (*TokenResponse, error) {
68-
log.Info().Msg("Refreshing token...")
69-
70-
data := url.Values{}
71-
data.Set("grant_type", "refresh_token")
72-
data.Set("refresh_token", refreshToken)
73-
74-
return c.makeTokenRequest(data)
75-
}
76-
77-
func (c *Client) GetToken(username, password string) (*TokenResponse, error) {
78-
data := url.Values{}
79-
data.Set("scope", strings.Join(c.cfg.OIDC.Scopes, " "))
64+
func (c *Client) GetToken(username, password string) (*oauth2.Token, error) {
65+
ctx := context.Background()
8066

8167
// Use config credentials if provided, otherwise fallback to args or client credentials
8268
user := username
@@ -88,69 +74,57 @@ func (c *Client) GetToken(username, password string) (*TokenResponse, error) {
8874
pass = c.cfg.User.Password
8975
}
9076

77+
var token *oauth2.Token
78+
var err error
79+
9180
if user != "" && pass != "" {
9281
log.Info().Str("grant_type", "password").Msg("Using Resource Owner Password Credentials flow")
93-
data.Set("grant_type", "password")
94-
data.Set("username", user)
95-
data.Set("password", pass)
82+
token, err = c.oauth2Config.PasswordCredentialsToken(ctx, user, pass)
9683
} else {
9784
log.Info().Str("grant_type", "client_credentials").Msg("Using Client Credentials flow")
98-
data.Set("grant_type", "client_credentials")
99-
}
100-
101-
return c.makeTokenRequest(data)
102-
}
103-
104-
func (c *Client) makeTokenRequest(data url.Values) (*TokenResponse, error) {
105-
// Handle Auth Method
106-
// Handle Auth Method
107-
// Default to basic
108-
// RFC 6749 says client_id in body is NOT RECOMMENDED for Basic Auth,
109-
// but we'll leave it out to be strict.
110-
// If the user wants it in body, they should use "post" or we'd need a "basic_with_body" option.
111-
// For now, let's stick to strict Basic Auth.
112-
if c.cfg.OIDC.AuthMethod == "post" {
113-
data.Set("client_id", c.cfg.OIDC.ClientID)
114-
data.Set("client_secret", c.cfg.OIDC.ClientSecret)
85+
// For client credentials, we need to create a clientcredentials.Config
86+
ccConfig := clientcredentials.Config{
87+
ClientID: c.oauth2Config.ClientID,
88+
ClientSecret: c.oauth2Config.ClientSecret,
89+
TokenURL: c.oauth2Config.Endpoint.TokenURL,
90+
Scopes: c.oauth2Config.Scopes,
91+
AuthStyle: c.oauth2Config.Endpoint.AuthStyle,
92+
}
93+
// The clientcredentials.Config should use the http client set in the context
94+
token, err = ccConfig.Token(ctx)
11595
}
11696

117-
log.Debug().
118-
Str("endpoint", c.endpoints.TokenEndpoint).
119-
Str("auth_method", c.cfg.OIDC.AuthMethod).
120-
Str("grant_type", data.Get("grant_type")).
121-
Msg("Making token request")
122-
123-
req, err := http.NewRequest("POST", c.endpoints.TokenEndpoint, strings.NewReader(data.Encode()))
12497
if err != nil {
125-
return nil, fmt.Errorf("failed to create request: %w", err)
98+
return nil, fmt.Errorf("failed to get token: %w", err)
12699
}
127100

128-
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
129-
130-
if c.cfg.OIDC.AuthMethod != "post" {
131-
req.SetBasicAuth(c.cfg.OIDC.ClientID, c.cfg.OIDC.ClientSecret)
101+
// Validate ID Token if present
102+
if idTokenRaw, ok := token.Extra("id_token").(string); ok && idTokenRaw != "" {
103+
verifier := c.provider.Verifier(&oidc.Config{ClientID: c.cfg.OIDC.ClientID})
104+
idToken, err := verifier.Verify(ctx, idTokenRaw)
105+
if err != nil {
106+
return nil, fmt.Errorf("failed to verify ID token: %w", err)
107+
}
108+
log.Debug().
109+
Str("issuer", idToken.Issuer).
110+
Str("subject", idToken.Subject).
111+
Msg("ID Token validated successfully")
112+
} else {
113+
log.Debug().Msg("No ID Token found or provided in response")
132114
}
133115

134-
resp, err := c.httpClient.Do(req)
135-
if err != nil {
136-
return nil, fmt.Errorf("token request failed: %w", err)
137-
}
138-
defer resp.Body.Close()
116+
return token, nil
117+
}
139118

140-
if resp.StatusCode != http.StatusOK {
141-
// Try to read body for error details
142-
var errResp map[string]interface{}
143-
if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil {
144-
log.Debug().Err(err).Msg("Failed to decode error response body")
145-
}
146-
log.Debug().Interface("error_response", errResp).Msg("Token request failed")
147-
return nil, fmt.Errorf("token request returned status %d: %v", resp.StatusCode, errResp)
148-
}
119+
// RefreshToken refreshes an expired token using the oauth2 library.
120+
// It takes the existing *oauth2.Token which must contain a valid RefreshToken.
121+
func (c *Client) RefreshToken(oldToken *oauth2.Token) (*oauth2.Token, error) {
122+
ctx := context.Background()
149123

150-
var tokenResp TokenResponse
151-
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
152-
return nil, fmt.Errorf("failed to decode token response: %w", err)
124+
tokenSource := c.oauth2Config.TokenSource(ctx, oldToken)
125+
newToken, err := tokenSource.Token()
126+
if err != nil {
127+
return nil, fmt.Errorf("failed to refresh token: %w", err)
153128
}
154-
155-
return &tokenResp, nil
129+
return newToken, nil
156130
}

0 commit comments

Comments
 (0)