diff --git a/pkg/gateway/catalog.go b/pkg/gateway/catalog.go index b90b96a1b..ed7299a77 100644 --- a/pkg/gateway/catalog.go +++ b/pkg/gateway/catalog.go @@ -19,10 +19,10 @@ const ( DockerCatalogURL = "https://desktop.docker.com/mcp/catalog/v3/catalog.yaml" catalogCacheFileName = "mcp_catalog.json" fetchTimeout = 15 * time.Second -) -// catalogJSON is the URL we actually fetch (JSON is ~3x faster to parse than YAML). -var catalogJSON = strings.Replace(DockerCatalogURL, ".yaml", ".json", 1) + // catalogJSON is the URL we actually fetch (JSON is ~3x faster to parse than YAML). + catalogJSON = "https://desktop.docker.com/mcp/catalog/v3/catalog.json" +) func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) { server, err := ServerSpec(ctx, serverName) @@ -40,8 +40,8 @@ func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) { return server.Secrets, nil } -func ServerSpec(ctx context.Context, serverName string) (Server, error) { - catalog, err := loadCatalog(ctx) +func ServerSpec(_ context.Context, serverName string) (Server, error) { + catalog, err := catalogOnce() if err != nil { return Server{}, err } @@ -54,6 +54,11 @@ func ServerSpec(ctx context.Context, serverName string) (Server, error) { return server, nil } +// ParseServerRef strips the optional "docker:" prefix from a server reference. +func ParseServerRef(ref string) string { + return strings.TrimPrefix(ref, "docker:") +} + // cachedCatalog is the on-disk cache format. type cachedCatalog struct { Catalog Catalog `json:"catalog"` @@ -69,12 +74,6 @@ var catalogOnce = sync.OnceValues(func() (Catalog, error) { return fetchAndCache(context.Background()) }) -// loadCatalog returns the catalog, fetching it at most once per process run. -// On network failure it falls back to the disk cache. -func loadCatalog(_ context.Context) (Catalog, error) { - return catalogOnce() -} - // fetchAndCache tries to fetch the catalog from the network (using ETag for // conditional requests) and falls back to the disk cache on any failure. func fetchAndCache(ctx context.Context) (Catalog, error) { @@ -128,16 +127,24 @@ func saveToDisk(path string, catalog Catalog, etag string) { } dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0o755); err != nil { - slog.Warn("Failed to create MCP catalog cache directory", "error", err) - return - } // Write to a temp file and rename so readers never see a partial file. + // Try creating the temp file first; only create the directory if needed. tmp, err := os.CreateTemp(dir, ".mcp_catalog_*.tmp") if err != nil { - slog.Warn("Failed to create MCP catalog temp file", "error", err) - return + if !os.IsNotExist(err) { + slog.Warn("Failed to create MCP catalog temp file", "error", err) + return + } + if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil { + slog.Warn("Failed to create MCP catalog cache directory", "error", mkErr) + return + } + tmp, err = os.CreateTemp(dir, ".mcp_catalog_*.tmp") + if err != nil { + slog.Warn("Failed to create MCP catalog temp file", "error", err) + return + } } tmpName := tmp.Name() @@ -159,6 +166,10 @@ func saveToDisk(path string, catalog Catalog, etag string) { } } +// catalogClient is a dedicated HTTP client for catalog fetches, isolated from +// http.DefaultClient so that other parts of the process cannot interfere. +var catalogClient = &http.Client{} + // fetchFromNetwork fetches the catalog, using the ETag for conditional requests. // It returns (nil, "", nil) when the server responds with 304 Not Modified. func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error) { @@ -174,7 +185,7 @@ func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error) req.Header.Set("If-None-Match", etag) } - resp, err := http.DefaultClient.Do(req) + resp, err := catalogClient.Do(req) if err != nil { return nil, "", err } diff --git a/pkg/gateway/catalog_test.go b/pkg/gateway/catalog_test.go index c30341adb..f1d2cf040 100644 --- a/pkg/gateway/catalog_test.go +++ b/pkg/gateway/catalog_test.go @@ -1,12 +1,45 @@ package gateway import ( + "os" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// testCatalog is a self-contained catalog used by all tests, removing the +// dependency on the live Docker MCP catalog and the network. +var testCatalog = Catalog{ + "github-official": { + Type: "server", + Secrets: []Secret{ + {Name: "github.personal_access_token", Env: "GITHUB_PERSONAL_ACCESS_TOKEN"}, + }, + }, + "fetch": { + Type: "server", + }, + "apify": { + Type: "remote", + Secrets: []Secret{ + {Name: "apify.token", Env: "APIFY_TOKEN"}, + }, + Remote: Remote{ + URL: "https://mcp.apify.com", + TransportType: "streamable-http", + }, + }, +} + +func TestMain(m *testing.M) { + // Override the production catalogOnce so that tests never hit the network. + catalogOnce = func() (Catalog, error) { + return testCatalog, nil + } + os.Exit(m.Run()) +} + func TestRequiredEnvVars_local(t *testing.T) { secrets, err := RequiredEnvVars(t.Context(), "github-official") require.NoError(t, err) @@ -38,3 +71,15 @@ func TestServerSpec_remote(t *testing.T) { assert.Equal(t, "https://mcp.apify.com", server.Remote.URL) assert.Equal(t, "streamable-http", server.Remote.TransportType) } + +func TestServerSpec_notFound(t *testing.T) { + _, err := ServerSpec(t.Context(), "nonexistent") + require.Error(t, err) + + assert.Contains(t, err.Error(), "not found in MCP catalog") +} + +func TestParseServerRef(t *testing.T) { + assert.Equal(t, "github-official", ParseServerRef("docker:github-official")) + assert.Equal(t, "github-official", ParseServerRef("github-official")) +} diff --git a/pkg/gateway/servers.go b/pkg/gateway/servers.go deleted file mode 100644 index 458891777..000000000 --- a/pkg/gateway/servers.go +++ /dev/null @@ -1,9 +0,0 @@ -package gateway - -import ( - "strings" -) - -func ParseServerRef(ref string) string { - return strings.TrimPrefix(ref, "docker:") -} diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index d9275991e..59222c587 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -263,7 +263,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon envProvider, ) - return mcp.NewGatewayToolset(ctx, toolset.Name, mcpServerName, toolset.Config, envProvider, runConfig.WorkingDir) + return mcp.NewGatewayToolset(ctx, toolset.Name, mcpServerName, serverSpec.Secrets, toolset.Config, envProvider, runConfig.WorkingDir) // STDIO MCP Server from shell command case toolset.Command != "": diff --git a/pkg/tools/mcp/gateway.go b/pkg/tools/mcp/gateway.go index d3912a05e..2d914de89 100644 --- a/pkg/tools/mcp/gateway.go +++ b/pkg/tools/mcp/gateway.go @@ -22,15 +22,9 @@ type GatewayToolset struct { var _ tools.ToolSet = (*GatewayToolset)(nil) -func NewGatewayToolset(ctx context.Context, name, mcpServerName string, config any, envProvider environment.Provider, cwd string) (*GatewayToolset, error) { +func NewGatewayToolset(ctx context.Context, name, mcpServerName string, secrets []gateway.Secret, config any, envProvider environment.Provider, cwd string) (*GatewayToolset, error) { slog.Debug("Creating MCP Gateway toolset", "name", mcpServerName) - // Check which secrets (env vars) are required by the MCP server. - secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName) - if err != nil { - return nil, fmt.Errorf("reading which secrets the MCP server needs: %w", err) - } - // Make sure all the required secrets are available in the environment. // TODO(dga): Ideally, the MCP gateway would use the same provider that we have. fileSecrets, err := writeSecretsToFile(ctx, mcpServerName, secrets, envProvider) @@ -66,7 +60,14 @@ func NewGatewayToolset(ctx context.Context, name, mcpServerName string, config a } func (t *GatewayToolset) Stop(ctx context.Context) error { - return errors.Join(t.Toolset.Stop(ctx), t.cleanUp()) + stopErr := t.Toolset.Stop(ctx) + + cleanUpErr := t.cleanUp() + if cleanUpErr != nil { + slog.Warn("Failed to clean up MCP Gateway temp files", "error", cleanUpErr) + } + + return errors.Join(stopErr, cleanUpErr) } func writeSecretsToFile(ctx context.Context, mcpServerName string, secrets []gateway.Secret, envProvider environment.Provider) (string, error) { @@ -77,6 +78,10 @@ func writeSecretsToFile(ctx context.Context, mcpServerName string, secrets []gat return "", errors.New("missing environment variable " + secret.Env + " required by MCP server " + mcpServerName) } + if strings.ContainsAny(v, "\n\r") { + return "", fmt.Errorf("secret %s contains newline characters", secret.Env) + } + secretValues = append(secretValues, fmt.Sprintf("%s=%s", secret.Name, v)) } @@ -100,9 +105,15 @@ func writeTempFile(nameTemplate string, content []byte) (string, error) { if err != nil { return "", fmt.Errorf("creating temp file: %w", err) } - defer f.Close() if _, err := f.Write(content); err != nil { + f.Close() + os.Remove(f.Name()) + return "", err + } + + if err := f.Close(); err != nil { + os.Remove(f.Name()) return "", err }