diff --git a/internal/handlers/nuget_feed.go b/internal/handlers/nuget_feed.go index 34d1769..4191dc0 100644 --- a/internal/handlers/nuget_feed.go +++ b/internal/handlers/nuget_feed.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" "github.com/elazarl/goproxy" @@ -36,9 +35,8 @@ type nugetV3IndexResponse struct { // NugetFeedHandler handles requests to nuget feeds, adding auth. type NugetFeedHandler struct { - credentials []nugetFeedCredentials - oidcCredentials map[string]*oidc.OIDCCredential - mutex sync.RWMutex + credentials []nugetFeedCredentials + oidcRegistry *oidc.OIDCRegistry } type nugetFeedCredentials struct { @@ -52,8 +50,8 @@ type nugetFeedCredentials struct { // NewNugetFeedHandler returns a new NugetFeedHandler. func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler { handler := NugetFeedHandler{ - credentials: []nugetFeedCredentials{}, - oidcCredentials: make(map[string]*oidc.OIDCCredential), + credentials: []nugetFeedCredentials{}, + oidcRegistry: oidc.NewOIDCRegistry(), } httpClient := &http.Client{ @@ -72,58 +70,68 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler { username := cred.GetString("username") password := cred.GetString("password") - oidcCredential, _ := oidc.CreateOIDCCredential(cred) - if oidcCredential != nil { - key := url - if key == "" { - key = host - } + oidcCredential, _, ok := handler.oidcRegistry.Register(cred, []string{"url"}, "nuget feed") + if ok { + // Discover additional resource URLs from the nuget feed index. + // Host-only credentials (from the CLI) are still registered above + // for request-time matching, but discovery requires an absolute URL. + // Wrapped in a closure so defer runs promptly for each credential, + // ensuring the HTTP response body is always closed (pre-existing + // leak fixed here: the body was previously leaked on ReadAll error + // and on early-return status code paths). + if url != "" { + func() { + req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil) + if err != nil { + logging.RequestLogf(nil, "error creating http request (%s): %v", url, err) + return + } - if key != "" { - handler.oidcCredentials[key] = oidcCredential - logging.RequestLogf(nil, "registered %s OIDC credentials for nuget feed: %s", oidcCredential.Provider(), key) + if req.URL.Scheme != "https" { + logging.RequestLogf(nil, "refusing to discover nuget feed over non-https URL %s", url) + return + } - // now query all resources to add to the authentication list - req, err := http.NewRequestWithContext(context.Background(), "GET", key, nil) - if err != nil { - logging.RequestLogf(nil, "error creating http request (%s): %v", key, err) - continue - } + if !handler.oidcRegistry.TryAuth(req, nil) { + return + } - if oidc.TryAuthOIDCRequestWithPrefix(&handler.mutex, handler.oidcCredentials, req, nil) { rawRsp, err := httpClient.Do(req) if err != nil { - logging.RequestLogf(nil, "error retrieving http response (%s): %v", key, err) - continue + logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err) + return } + defer rawRsp.Body.Close() body, err := io.ReadAll(rawRsp.Body) if err != nil { - logging.RequestLogf(nil, "error reading http response body") - continue + logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err) + return } - rawRsp.Body.Close() switch rawRsp.StatusCode { case 401, 403: - logging.RequestLogf(nil, "unauthorized for nuget feed %s", key) - continue + logging.RequestLogf(nil, "unauthorized for nuget feed %s", url) + return } if rawRsp.StatusCode >= 400 { - logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, key) - continue + logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url) + return } - urlsToAuthenticate := extraUrlsFromSourceResponse(body, key) - for _, url := range urlsToAuthenticate { - handler.oidcCredentials[url] = oidcCredential - logging.RequestLogf(nil, " registered %s OIDC credentials for nuget resource: %s", oidcCredential.Provider(), url) + urlsToAuthenticate := extraUrlsFromSourceResponse(body, url) + for _, discoveredURL := range urlsToAuthenticate { + handler.oidcRegistry.RegisterURL(discoveredURL, oidcCredential, "nuget resource") } - } + }() } continue } + // OIDC credentials are not used as static credentials. + if oidcCredential != nil { + continue + } feedCred := nugetFeedCredentials{ url: url, @@ -138,48 +146,52 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler { // and authenticate them all if url != "" { logging.RequestLogf(nil, "fetching service index for nuget feed %s", url) - req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil) - authenticateNugetRequest(req, feedCred, nil) - if err != nil { - logging.RequestLogf(nil, "error creating http request (%s): %v", url, err) - continue - } + // Same closure pattern as the OIDC block above — ensures the + // HTTP response body is always closed via defer. + func() { + req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil) + if err != nil { + logging.RequestLogf(nil, "error creating http request (%s): %v", url, err) + return + } + authenticateNugetRequest(req, feedCred, nil) - rawRsp, err := httpClient.Do(req) - if err != nil { - logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err) - continue - } + rawRsp, err := httpClient.Do(req) + if err != nil { + logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err) + return + } + defer rawRsp.Body.Close() - body, err := io.ReadAll(rawRsp.Body) - if err != nil { - logging.RequestLogf(nil, "error reading http response body") - continue - } - rawRsp.Body.Close() + body, err := io.ReadAll(rawRsp.Body) + if err != nil { + logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err) + return + } - switch rawRsp.StatusCode { - case 401, 403: - logging.RequestLogf(nil, "unauthorized for nuget feed %s", url) - continue - } + switch rawRsp.StatusCode { + case 401, 403: + logging.RequestLogf(nil, "unauthorized for nuget feed %s", url) + return + } - if rawRsp.StatusCode >= 400 { - logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url) - continue - } + if rawRsp.StatusCode >= 400 { + logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url) + return + } - urlsToAuthenticate := extraUrlsFromSourceResponse(body, url) - for _, url := range urlsToAuthenticate { - feedCred := nugetFeedCredentials{ - url: url, - token: token, - username: username, - password: password, + urlsToAuthenticate := extraUrlsFromSourceResponse(body, url) + for _, discoveredURL := range urlsToAuthenticate { + feedCred := nugetFeedCredentials{ + url: discoveredURL, + token: token, + username: username, + password: password, + } + handler.credentials = append(handler.credentials, feedCred) + logging.RequestLogf(nil, " added url to authentication list: %s", discoveredURL) } - handler.credentials = append(handler.credentials, feedCred) - logging.RequestLogf(nil, " added url to authentication list: %s", url) - } + }() } } @@ -261,8 +273,8 @@ func (h *NugetFeedHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt return req, nil } - // Try OIDC credentials first - if oidc.TryAuthOIDCRequestWithPrefix(&h.mutex, h.oidcCredentials, req, ctx) { + // Try OIDC credentials first (HTTPS only to avoid leaking tokens over plaintext) + if req.URL.Scheme == "https" && h.oidcRegistry.TryAuth(req, ctx) { return req, nil } diff --git a/internal/handlers/oidc_handling_test.go b/internal/handlers/oidc_handling_test.go index dc742e6..e4c57e3 100644 --- a/internal/handlers/oidc_handling_test.go +++ b/internal/handlers/oidc_handling_test.go @@ -822,7 +822,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) { }, expectedLogLines: []string{ "registered aws OIDC credentials for nuget feed: https://nuget.example.com/index.json", - " registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages", + "registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages", }, urlsToAuthenticate: []string{ "https://nuget.example.com/index.json", // base url @@ -852,7 +852,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) { }, expectedLogLines: []string{ "registered azure OIDC credentials for nuget feed: https://nuget.example.com/index.json", - " registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages", + "registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages", }, urlsToAuthenticate: []string{ "https://nuget.example.com/index.json", // base url @@ -881,7 +881,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) { }, expectedLogLines: []string{ "registered jfrog OIDC credentials for nuget feed: https://jfrog.example.com/index.json", - " registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages", + "registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages", }, urlsToAuthenticate: []string{ "https://jfrog.example.com/index.json", // base url @@ -912,7 +912,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) { }, expectedLogLines: []string{ "registered cloudsmith OIDC credentials for nuget feed: https://cloudsmith.example.com/v3/index.json", - " registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages", + "registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages", }, urlsToAuthenticate: []string{ "https://cloudsmith.example.com/v3/index.json", // base url