Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions pkg/gateway/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ const (
DockerCatalogURL = "https://desktop.docker.com/mcp/catalog/v3/catalog.yaml"
catalogCacheFileName = "mcp_catalog.json"
fetchTimeout = 15 * time.Second
)

// catalogJSON is the URL we actually fetch (JSON is ~3x faster to parse than YAML).
var catalogJSON = strings.Replace(DockerCatalogURL, ".yaml", ".json", 1)
// catalogJSON is the URL we actually fetch (JSON is ~3x faster to parse than YAML).
catalogJSON = "https://desktop.docker.com/mcp/catalog/v3/catalog.json"
)

func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) {
server, err := ServerSpec(ctx, serverName)
Expand All @@ -40,8 +40,8 @@ func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) {
return server.Secrets, nil
}

func ServerSpec(ctx context.Context, serverName string) (Server, error) {
catalog, err := loadCatalog(ctx)
func ServerSpec(_ context.Context, serverName string) (Server, error) {
catalog, err := catalogOnce()
if err != nil {
return Server{}, err
}
Expand All @@ -54,6 +54,11 @@ func ServerSpec(ctx context.Context, serverName string) (Server, error) {
return server, nil
}

// ParseServerRef strips the optional "docker:" prefix from a server reference.
func ParseServerRef(ref string) string {
return strings.TrimPrefix(ref, "docker:")
}

// cachedCatalog is the on-disk cache format.
type cachedCatalog struct {
Catalog Catalog `json:"catalog"`
Expand All @@ -69,12 +74,6 @@ var catalogOnce = sync.OnceValues(func() (Catalog, error) {
return fetchAndCache(context.Background())
})

// loadCatalog returns the catalog, fetching it at most once per process run.
// On network failure it falls back to the disk cache.
func loadCatalog(_ context.Context) (Catalog, error) {
return catalogOnce()
}

// fetchAndCache tries to fetch the catalog from the network (using ETag for
// conditional requests) and falls back to the disk cache on any failure.
func fetchAndCache(ctx context.Context) (Catalog, error) {
Expand Down Expand Up @@ -128,16 +127,24 @@ func saveToDisk(path string, catalog Catalog, etag string) {
}

dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
slog.Warn("Failed to create MCP catalog cache directory", "error", err)
return
}

// Write to a temp file and rename so readers never see a partial file.
// Try creating the temp file first; only create the directory if needed.
tmp, err := os.CreateTemp(dir, ".mcp_catalog_*.tmp")
if err != nil {
slog.Warn("Failed to create MCP catalog temp file", "error", err)
return
if !os.IsNotExist(err) {
slog.Warn("Failed to create MCP catalog temp file", "error", err)
return
}
if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
slog.Warn("Failed to create MCP catalog cache directory", "error", mkErr)
return
}
tmp, err = os.CreateTemp(dir, ".mcp_catalog_*.tmp")
if err != nil {
slog.Warn("Failed to create MCP catalog temp file", "error", err)
return
}
}
tmpName := tmp.Name()

Expand All @@ -159,6 +166,10 @@ func saveToDisk(path string, catalog Catalog, etag string) {
}
}

// catalogClient is a dedicated HTTP client for catalog fetches, isolated from
// http.DefaultClient so that other parts of the process cannot interfere.
var catalogClient = &http.Client{}

// fetchFromNetwork fetches the catalog, using the ETag for conditional requests.
// It returns (nil, "", nil) when the server responds with 304 Not Modified.
func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error) {
Expand All @@ -174,7 +185,7 @@ func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error)
req.Header.Set("If-None-Match", etag)
}

resp, err := http.DefaultClient.Do(req)
resp, err := catalogClient.Do(req)
if err != nil {
return nil, "", err
}
Expand Down
45 changes: 45 additions & 0 deletions pkg/gateway/catalog_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,45 @@
package gateway

import (
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// testCatalog is a self-contained catalog used by all tests, removing the
// dependency on the live Docker MCP catalog and the network.
var testCatalog = Catalog{
"github-official": {
Type: "server",
Secrets: []Secret{
{Name: "github.personal_access_token", Env: "GITHUB_PERSONAL_ACCESS_TOKEN"},
},
},
"fetch": {
Type: "server",
},
"apify": {
Type: "remote",
Secrets: []Secret{
{Name: "apify.token", Env: "APIFY_TOKEN"},
},
Remote: Remote{
URL: "https://mcp.apify.com",
TransportType: "streamable-http",
},
},
}

func TestMain(m *testing.M) {
// Override the production catalogOnce so that tests never hit the network.
catalogOnce = func() (Catalog, error) {
return testCatalog, nil
}
os.Exit(m.Run())
}

func TestRequiredEnvVars_local(t *testing.T) {
secrets, err := RequiredEnvVars(t.Context(), "github-official")
require.NoError(t, err)
Expand Down Expand Up @@ -38,3 +71,15 @@ func TestServerSpec_remote(t *testing.T) {
assert.Equal(t, "https://mcp.apify.com", server.Remote.URL)
assert.Equal(t, "streamable-http", server.Remote.TransportType)
}

func TestServerSpec_notFound(t *testing.T) {
_, err := ServerSpec(t.Context(), "nonexistent")
require.Error(t, err)

assert.Contains(t, err.Error(), "not found in MCP catalog")
}

func TestParseServerRef(t *testing.T) {
assert.Equal(t, "github-official", ParseServerRef("docker:github-official"))
assert.Equal(t, "github-official", ParseServerRef("github-official"))
}
9 changes: 0 additions & 9 deletions pkg/gateway/servers.go

This file was deleted.

2 changes: 1 addition & 1 deletion pkg/teamloader/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon
envProvider,
)

return mcp.NewGatewayToolset(ctx, toolset.Name, mcpServerName, toolset.Config, envProvider, runConfig.WorkingDir)
return mcp.NewGatewayToolset(ctx, toolset.Name, mcpServerName, serverSpec.Secrets, toolset.Config, envProvider, runConfig.WorkingDir)

// STDIO MCP Server from shell command
case toolset.Command != "":
Expand Down
29 changes: 20 additions & 9 deletions pkg/tools/mcp/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,9 @@ type GatewayToolset struct {

var _ tools.ToolSet = (*GatewayToolset)(nil)

func NewGatewayToolset(ctx context.Context, name, mcpServerName string, config any, envProvider environment.Provider, cwd string) (*GatewayToolset, error) {
func NewGatewayToolset(ctx context.Context, name, mcpServerName string, secrets []gateway.Secret, config any, envProvider environment.Provider, cwd string) (*GatewayToolset, error) {
slog.Debug("Creating MCP Gateway toolset", "name", mcpServerName)

// Check which secrets (env vars) are required by the MCP server.
secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName)
if err != nil {
return nil, fmt.Errorf("reading which secrets the MCP server needs: %w", err)
}

// Make sure all the required secrets are available in the environment.
// TODO(dga): Ideally, the MCP gateway would use the same provider that we have.
fileSecrets, err := writeSecretsToFile(ctx, mcpServerName, secrets, envProvider)
Expand Down Expand Up @@ -66,7 +60,14 @@ func NewGatewayToolset(ctx context.Context, name, mcpServerName string, config a
}

func (t *GatewayToolset) Stop(ctx context.Context) error {
return errors.Join(t.Toolset.Stop(ctx), t.cleanUp())
stopErr := t.Toolset.Stop(ctx)

cleanUpErr := t.cleanUp()
if cleanUpErr != nil {
slog.Warn("Failed to clean up MCP Gateway temp files", "error", cleanUpErr)
}

return errors.Join(stopErr, cleanUpErr)
}

func writeSecretsToFile(ctx context.Context, mcpServerName string, secrets []gateway.Secret, envProvider environment.Provider) (string, error) {
Expand All @@ -77,6 +78,10 @@ func writeSecretsToFile(ctx context.Context, mcpServerName string, secrets []gat
return "", errors.New("missing environment variable " + secret.Env + " required by MCP server " + mcpServerName)
}

if strings.ContainsAny(v, "\n\r") {
return "", fmt.Errorf("secret %s contains newline characters", secret.Env)
}

secretValues = append(secretValues, fmt.Sprintf("%s=%s", secret.Name, v))
}

Expand All @@ -100,9 +105,15 @@ func writeTempFile(nameTemplate string, content []byte) (string, error) {
if err != nil {
return "", fmt.Errorf("creating temp file: %w", err)
}
defer f.Close()

if _, err := f.Write(content); err != nil {
f.Close()
os.Remove(f.Name())
return "", err
}

if err := f.Close(); err != nil {
os.Remove(f.Name())
return "", err
}

Expand Down
Loading