Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 86 additions & 74 deletions internal/handlers/nuget_feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/elazarl/goproxy"
Expand All @@ -36,9 +35,8 @@ type nugetV3IndexResponse struct {

// NugetFeedHandler handles requests to nuget feeds, adding auth.
type NugetFeedHandler struct {
credentials []nugetFeedCredentials
oidcCredentials map[string]*oidc.OIDCCredential
mutex sync.RWMutex
credentials []nugetFeedCredentials
oidcRegistry *oidc.OIDCRegistry
}

type nugetFeedCredentials struct {
Expand All @@ -52,8 +50,8 @@ type nugetFeedCredentials struct {
// NewNugetFeedHandler returns a new NugetFeedHandler.
func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
handler := NugetFeedHandler{
credentials: []nugetFeedCredentials{},
oidcCredentials: make(map[string]*oidc.OIDCCredential),
credentials: []nugetFeedCredentials{},
oidcRegistry: oidc.NewOIDCRegistry(),
}

httpClient := &http.Client{
Expand All @@ -72,58 +70,68 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
username := cred.GetString("username")
password := cred.GetString("password")

oidcCredential, _ := oidc.CreateOIDCCredential(cred)
if oidcCredential != nil {
key := url
if key == "" {
key = host
}
oidcCredential, _, ok := handler.oidcRegistry.Register(cred, []string{"url"}, "nuget feed")
if ok {
// Discover additional resource URLs from the nuget feed index.
// Host-only credentials (from the CLI) are still registered above
// for request-time matching, but discovery requires an absolute URL.
// Wrapped in a closure so defer runs promptly for each credential,
// ensuring the HTTP response body is always closed (pre-existing
// leak fixed here: the body was previously leaked on ReadAll error
// and on early-return status code paths).
if url != "" {
func() {
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
if err != nil {
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
return
}

if key != "" {
handler.oidcCredentials[key] = oidcCredential
logging.RequestLogf(nil, "registered %s OIDC credentials for nuget feed: %s", oidcCredential.Provider(), key)
if req.URL.Scheme != "https" {
logging.RequestLogf(nil, "refusing to discover nuget feed over non-https URL %s", url)
return
}

// now query all resources to add to the authentication list
req, err := http.NewRequestWithContext(context.Background(), "GET", key, nil)
if err != nil {
logging.RequestLogf(nil, "error creating http request (%s): %v", key, err)
continue
}
if !handler.oidcRegistry.TryAuth(req, nil) {
return
}

if oidc.TryAuthOIDCRequestWithPrefix(&handler.mutex, handler.oidcCredentials, req, nil) {
rawRsp, err := httpClient.Do(req)
if err != nil {
logging.RequestLogf(nil, "error retrieving http response (%s): %v", key, err)
continue
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
return
}
defer rawRsp.Body.Close()

body, err := io.ReadAll(rawRsp.Body)
if err != nil {
logging.RequestLogf(nil, "error reading http response body")
continue
logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err)
return
}
rawRsp.Body.Close()

switch rawRsp.StatusCode {
case 401, 403:
logging.RequestLogf(nil, "unauthorized for nuget feed %s", key)
continue
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
return
}

if rawRsp.StatusCode >= 400 {
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, key)
continue
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
return
}

urlsToAuthenticate := extraUrlsFromSourceResponse(body, key)
for _, url := range urlsToAuthenticate {
handler.oidcCredentials[url] = oidcCredential
logging.RequestLogf(nil, " registered %s OIDC credentials for nuget resource: %s", oidcCredential.Provider(), url)
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
for _, discoveredURL := range urlsToAuthenticate {
handler.oidcRegistry.RegisterURL(discoveredURL, oidcCredential, "nuget resource")
}
}
}()
}
continue
}
// OIDC credentials are not used as static credentials.
if oidcCredential != nil {
continue
}

feedCred := nugetFeedCredentials{
url: url,
Expand All @@ -138,48 +146,52 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
// and authenticate them all
if url != "" {
logging.RequestLogf(nil, "fetching service index for nuget feed %s", url)
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
authenticateNugetRequest(req, feedCred, nil)
if err != nil {
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
continue
}
// Same closure pattern as the OIDC block above — ensures the
// HTTP response body is always closed via defer.
func() {
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
if err != nil {
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
return
}
authenticateNugetRequest(req, feedCred, nil)

rawRsp, err := httpClient.Do(req)
if err != nil {
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
continue
}
rawRsp, err := httpClient.Do(req)
if err != nil {
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
return
}
defer rawRsp.Body.Close()

body, err := io.ReadAll(rawRsp.Body)
if err != nil {
logging.RequestLogf(nil, "error reading http response body")
continue
}
rawRsp.Body.Close()
body, err := io.ReadAll(rawRsp.Body)
if err != nil {
logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err)
return
}

switch rawRsp.StatusCode {
case 401, 403:
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
continue
}
switch rawRsp.StatusCode {
case 401, 403:
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
return
}

if rawRsp.StatusCode >= 400 {
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
continue
}
if rawRsp.StatusCode >= 400 {
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
return
}

urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
for _, url := range urlsToAuthenticate {
feedCred := nugetFeedCredentials{
url: url,
token: token,
username: username,
password: password,
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
for _, discoveredURL := range urlsToAuthenticate {
feedCred := nugetFeedCredentials{
url: discoveredURL,
token: token,
username: username,
password: password,
}
handler.credentials = append(handler.credentials, feedCred)
logging.RequestLogf(nil, " added url to authentication list: %s", discoveredURL)
}
handler.credentials = append(handler.credentials, feedCred)
logging.RequestLogf(nil, " added url to authentication list: %s", url)
}
}()
}
}

Expand Down Expand Up @@ -261,8 +273,8 @@ func (h *NugetFeedHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt
return req, nil
}

// Try OIDC credentials first
if oidc.TryAuthOIDCRequestWithPrefix(&h.mutex, h.oidcCredentials, req, ctx) {
// Try OIDC credentials first (HTTPS only to avoid leaking tokens over plaintext)
if req.URL.Scheme == "https" && h.oidcRegistry.TryAuth(req, ctx) {
return req, nil
}

Expand Down
8 changes: 4 additions & 4 deletions internal/handlers/oidc_handling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
},
expectedLogLines: []string{
"registered aws OIDC credentials for nuget feed: https://nuget.example.com/index.json",
" registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
"registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
},
urlsToAuthenticate: []string{
"https://nuget.example.com/index.json", // base url
Expand Down Expand Up @@ -852,7 +852,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
},
expectedLogLines: []string{
"registered azure OIDC credentials for nuget feed: https://nuget.example.com/index.json",
" registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
"registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
},
urlsToAuthenticate: []string{
"https://nuget.example.com/index.json", // base url
Expand Down Expand Up @@ -881,7 +881,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
},
expectedLogLines: []string{
"registered jfrog OIDC credentials for nuget feed: https://jfrog.example.com/index.json",
" registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages",
"registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages",
},
urlsToAuthenticate: []string{
"https://jfrog.example.com/index.json", // base url
Expand Down Expand Up @@ -912,7 +912,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
},
expectedLogLines: []string{
"registered cloudsmith OIDC credentials for nuget feed: https://cloudsmith.example.com/v3/index.json",
" registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages",
"registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages",
},
urlsToAuthenticate: []string{
"https://cloudsmith.example.com/v3/index.json", // base url
Expand Down
Loading