diff --git a/AGENTS.md b/AGENTS.md index b60c794..711d355 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,7 +33,8 @@ Read this before changing `artifact-fs`. ## Non-obvious CLI/runtime behavior - `ARTIFACT_FS_ROOT` is the state root. `artifact-fs daemon --root` is the mount root. They are different things. -- `add-repo` is one-shot: register repo, clone blobless, build the initial snapshot, then exit. It does not mount FUSE or start background goroutines. +- `add-repo` is one-shot by default: register repo, clone blobless, build the initial snapshot, then exit. It does not mount FUSE or start background goroutines. +- `add-repo --async` only registers prepare state. The daemon mounts a gated placeholder, prepares clone/fetch and snapshot in the background, then opens the gate and starts watcher/refresh. - `daemon` is long-running: it mounts registered repos and starts watcher, refresh, and hydrator workers. - `git.CloneBlobless` already populates the git index with `read-tree HEAD`; be careful about extra index resets because they can discard staged state. diff --git a/README.md b/README.md index 38b5dee..a494531 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Quick start against a public repo: ```bash export ARTIFACT_FS_ROOT=/tmp/artifact-fs-test -# Register and clone (returns immediately) +# Register, clone, and build the initial snapshot ./artifact-fs add-repo \ --name workers-sdk \ --remote https://github.com/cloudflare/workers-sdk.git \ @@ -102,6 +102,44 @@ Use `--hydration-concurrency` to control the number of parallel blob-fetch worke ./artifact-fs daemon --root /tmp --hydration-concurrency 8 ``` +## Async repo preparation + +By default, `add-repo` waits for the blobless clone and initial snapshot before returning. 
Use `--async` when the daemon should prepare the repo in the background: + +```bash +./artifact-fs add-repo \ + --name workers-sdk \ + --remote https://github.com/cloudflare/workers-sdk.git \ + --branch main \ + --mount-root /tmp \ + --async +``` + +The daemon mounts a placeholder immediately. Operations inside that repo mount, such as `ls`, `less`, or `git -C /tmp/workers-sdk status`, wait until the clone/fetch and snapshot publish have completed. If preparation fails, those operations return an I/O error until preparation is retried: + +```bash +./artifact-fs status --name workers-sdk +./artifact-fs prepare --name workers-sdk +``` + +Async HTTPS remotes must use ambient credentials, such as a configured Git credential helper or repo-local Git config. Inline credentials in the remote URL are rejected for async repositories. + +For workflows that create the gitdir separately, `--prepared-gitdir` makes the async step fetch and prepare an existing gitdir instead of running `git clone`: + +```bash +git init --separate-git-dir /tmp/workers-sdk.git --initial-branch main /tmp/workers-sdk +git -C /tmp/workers-sdk remote add origin https://github.com/cloudflare/workers-sdk.git + +./artifact-fs add-repo \ + --name workers-sdk \ + --branch main \ + --mount-root /tmp \ + --async \ + --prepared-gitdir \ + --git-dir /tmp/workers-sdk.git \ + --fetch-ref main +``` + ## Sandboxes and Containers [`examples/Dockerfile`](examples/Dockerfile) builds artifact-fs and starts a FUSE-mounted repo inside a container. The container requires `--cap-add SYS_ADMIN --device /dev/fuse` for FUSE access. @@ -129,7 +167,7 @@ On hosts with AppArmor enabled (Ubuntu default), add `--security-opt apparmor:un ## Architecture -ArtifactFS has two distinct phases: a one-shot **setup** (`add-repo`) that performs a fast blobless clone of a repo, and a long-running **daemon** that mounts it via FUSE and serves file operations. 
+ArtifactFS has two distinct phases: a one-shot **setup** (`add-repo`) that registers and usually prepares a fast blobless clone, and a long-running **daemon** that mounts it via FUSE and serves file operations. With `add-repo --async`, setup only registers the repo; the daemon performs clone/fetch and snapshot publishing while FUSE operations wait behind a readiness gate. ``` ┌─────────────────────────────────────────────────┐ @@ -174,7 +212,7 @@ ArtifactFS has two distinct phases: a one-shot **setup** (`add-repo`) that perfo ### Data flow -1. **Clone** -- `add-repo` runs `git clone --filter=blob:none` (blobless). Only commits, trees, and refs are fetched. No file content is downloaded. +1. **Clone/fetch** -- `add-repo` runs `git clone --filter=blob:none` (blobless) unless `--async` is used. In async mode, the daemon performs either the blobless clone or a fetch into a prepared gitdir. Only commits, trees, and refs are fetched. No file content is downloaded. 2. **Index** -- `git ls-tree -r -t -z HEAD` enumerates every path in the tree. Sizes are resolved locally via `git cat-file --batch-check` with `GIT_NO_LAZY_FETCH=1` to avoid network round-trips. The result is bulk-inserted into a SQLite `base_nodes` table as a new generation. 
diff --git a/e2e_async_test.go b/e2e_async_test.go new file mode 100644 index 0000000..7f49237 --- /dev/null +++ b/e2e_async_test.go @@ -0,0 +1,232 @@ +//go:build !windows + +package main + +import ( + "context" + "log/slog" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + "github.com/cloudflare/artifact-fs/internal/daemon" + "github.com/cloudflare/artifact-fs/internal/logging" + "github.com/cloudflare/artifact-fs/internal/model" +) + +func TestE2EAsyncPreparedGitDirBlocksUntilReady(t *testing.T) { + if os.Getenv("AFS_RUN_E2E_TESTS") != "1" { + t.Skip("skipping e2e tests (set AFS_RUN_E2E_TESTS=1 to run)") + } + skipIfNoFUSE(t) + + remoteURL := os.Getenv("AFS_E2E_REPO") + if remoteURL == "" { + remoteURL = createLocalTestRepo(t) + } + preparedGitDir, preparedWorktree := createPreparedGitDir(t, remoteURL) + _ = preparedWorktree + + unblock := filepath.Join(t.TempDir(), "unblock-fetch") + installBlockingGitFetchWrapper(t, unblock) + + repo := newAsyncPreparedE2ERepo(t, preparedGitDir, "main") + + waitForCondition(t, 10*time.Second, "async repo preparing", func() (bool, string) { + st, err := repo.svc.Status(context.Background(), repoName) + if err != nil { + return false, err.Error() + } + if st.State == model.PrepareStatePreparing { + return true, "" + } + return false, "state=" + st.State + }) + + done := make(chan error, 1) + go func() { + entries, err := os.ReadDir(repo.mountPath) + if err == nil && len(entries) == 0 { + err = os.ErrNotExist + } + done <- err + }() + + select { + case err := <-done: + t.Fatalf("ReadDir returned before async prepare was released: %v", err) + case <-time.After(500 * time.Millisecond): + } + + if err := os.WriteFile(unblock, []byte("go\n"), 0o644); err != nil { + t.Fatal(err) + } + select { + case err := <-done: + if err != nil { + t.Fatalf("ReadDir after prepare release: %v", err) + } + case <-time.After(30 * time.Second): + t.Fatal("ReadDir did not unblock after async prepare completed") + } + + entries := lsDir(t, 
repo.mountPath) + assertContains(t, entries, ".git") + assertContains(t, entries, "README.md") + assertGitStatus(t, repo.mountPath, map[string]string{}) +} + +func TestE2EAsyncPreparedGitDirFailureThenRetry(t *testing.T) { + if os.Getenv("AFS_RUN_E2E_TESTS") != "1" { + t.Skip("skipping e2e tests (set AFS_RUN_E2E_TESTS=1 to run)") + } + skipIfNoFUSE(t) + + remoteURL := os.Getenv("AFS_E2E_REPO") + if remoteURL == "" { + remoteURL = createLocalTestRepo(t) + } + preparedGitDir, preparedWorktree := createPreparedGitDir(t, "file://"+filepath.Join(t.TempDir(), "missing.git")) + repo := newAsyncPreparedE2ERepo(t, preparedGitDir, "main") + + waitForCondition(t, 10*time.Second, "async prepare failure", func() (bool, string) { + st, err := repo.svc.Status(context.Background(), repoName) + if err != nil { + return false, err.Error() + } + if st.State == model.PrepareStateFailed && st.PrepareError != "" { + return true, "" + } + return false, "state=" + st.State + " prepare_error=" + st.PrepareError + }) + + if _, err := os.ReadDir(repo.mountPath); err == nil { + t.Fatal("ReadDir unexpectedly succeeded after prepare failure") + } + + gitCmd(t, preparedWorktree, "remote", "set-url", "origin", remoteURL) + if err := repo.svc.Prepare(context.Background(), repoName); err != nil { + t.Fatalf("prepare retry: %v", err) + } + + waitForCondition(t, 30*time.Second, "async prepare retry ready", func() (bool, string) { + st, err := repo.svc.Status(context.Background(), repoName) + if err != nil { + return false, err.Error() + } + if st.State == "mounted" && st.PrepareError == "" { + return true, "" + } + return false, "state=" + st.State + " prepare_error=" + st.PrepareError + }) + + entries := lsDir(t, repo.mountPath) + assertContains(t, entries, "README.md") + assertGitStatus(t, repo.mountPath, map[string]string{}) +} + +func createPreparedGitDir(t *testing.T, remoteURL string) (gitDir string, worktree string) { + t.Helper() + tmp := t.TempDir() + gitDir = filepath.Join(tmp, 
"prepared.git") + worktree = filepath.Join(tmp, "prepared") + run(t, "", "git", "init", "--separate-git-dir", gitDir, "--initial-branch", "main", worktree) + run(t, worktree, "git", "remote", "add", "origin", remoteURL) + return gitDir, worktree +} + +func installBlockingGitFetchWrapper(t *testing.T, unblockPath string) { + t.Helper() + realGit, err := exec.LookPath("git") + if err != nil { + t.Fatal(err) + } + wrapperDir := t.TempDir() + wrapperPath := filepath.Join(wrapperDir, "git") + script := `#!/bin/sh +for arg in "$@"; do + if [ "$arg" = "fetch" ]; then + while [ ! -f "$AFS_ASYNC_GIT_UNBLOCK" ]; do + sleep 0.05 + done + break + fi +done +exec "$AFS_REAL_GIT" "$@" +` + if err := os.WriteFile(wrapperPath, []byte(script), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("AFS_REAL_GIT", realGit) + t.Setenv("AFS_ASYNC_GIT_UNBLOCK", unblockPath) + t.Setenv("PATH", wrapperDir+string(os.PathListSeparator)+os.Getenv("PATH")) +} + +func newAsyncPreparedE2ERepo(t *testing.T, preparedGitDir string, fetchRef string) *mountedE2ERepo { + t.Helper() + root, err := os.MkdirTemp("", "artifact-fs-e2e-async-root-*") + if err != nil { + t.Fatal(err) + } + mountDir, err := os.MkdirTemp("", "artifact-fs-e2e-async-mount-*") + if err != nil { + _ = os.RemoveAll(root) + t.Fatal(err) + } + mountPath := filepath.Join(mountDir, repoName) + if err := os.MkdirAll(mountPath, 0o755); err != nil { + _ = os.RemoveAll(mountDir) + _ = os.RemoveAll(root) + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + logger := logging.NewJSONLogger(os.Stderr, slog.LevelWarn) + svc, err := daemon.New(ctx, root, logger) + if err != nil { + cancel() + t.Fatal(err) + } + svc.SetMountRoot(mountDir) + + cfg := model.RepoConfig{ + Name: repoName, + ID: model.RepoID(repoName), + Branch: "main", + RefreshInterval: 5 * time.Minute, + MountRoot: mountDir, + GitDir: preparedGitDir, + PreparedGitDir: true, + FetchRef: fetchRef, + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, 
cfg, daemon.AddRepoOptions{Async: true}); err != nil { + cancel() + _ = svc.Close() + t.Fatalf("add-repo async prepared-gitdir: %v", err) + } + + errCh := make(chan error, 1) + go func() { errCh <- svc.Start(ctx) }() + + if !waitForMount(t, mountPath, 60*time.Second) { + cancel() + _ = svc.Close() + t.Fatal("FUSE mount did not appear within timeout") + } + + repo := &mountedE2ERepo{ + root: root, + mountDir: mountDir, + mountPath: mountPath, + svc: svc, + cancel: cancel, + errCh: errCh, + } + t.Cleanup(func() { + repo.close(t) + }) + return repo +} diff --git a/e2e_setup_darwin_test.go b/e2e_setup_darwin_test.go index 5b63116..cfc5950 100644 --- a/e2e_setup_darwin_test.go +++ b/e2e_setup_darwin_test.go @@ -8,12 +8,24 @@ import ( "testing" ) -// skipIfNoFUSE skips the test if macFUSE is not installed. +// skipIfNoFUSE skips the test if macFUSE is not installed or the +// mount helper cannot be executed by this test process. Some sandboxed +// environments expose the macFUSE bundle but deny exec of mount_macfuse; +// without this preflight every e2e case waits for its mount timeout. func skipIfNoFUSE(t *testing.T) { t.Helper() if _, err := os.Stat("/Library/Filesystems/macfuse.fs"); err != nil { t.Skip("skipping: macFUSE not installed") } + helper := "/Library/Filesystems/macfuse.fs/Contents/Resources/mount_macfuse" + cmd := exec.Command(helper, "--help") + out, err := cmd.CombinedOutput() + if err == nil { + return + } + if strings.Contains(err.Error(), "operation not permitted") || strings.Contains(string(out), "operation not permitted") { + t.Skipf("skipping: macFUSE mount helper is not executable in this environment: %v", err) + } } // configGitSafeDir adds safe.directory entries for the mount path. 
diff --git a/internal/auth/redact.go b/internal/auth/redact.go index 8531345..7e5ff76 100644 --- a/internal/auth/redact.go +++ b/internal/auth/redact.go @@ -1,6 +1,7 @@ package auth import ( + "net" "net/url" "regexp" "strings" @@ -14,7 +15,10 @@ func RedactRemoteURL(raw string) string { } u, err := url.Parse(raw) if err != nil { - return tokenLike.ReplaceAllString(raw, `$1=REDACTED`) + return redactMalformedURL(raw) + } + if u.User == nil && strings.Contains(raw, "@") && (isMalformedHTTPUserinfo(raw, u) || schemeLessUserinfoStart(raw) >= 0) { + return redactMalformedURL(raw) } if u.User != nil { username := u.User.Username() @@ -22,26 +26,462 @@ func RedactRemoteURL(raw string) string { u.User = url.User("REDACTED") } } - if u.RawQuery != "" { - u.RawQuery = tokenLike.ReplaceAllString(u.RawQuery, `$1=REDACTED`) + if u.RawQuery != "" || u.ForceQuery { + u.RawQuery = "REDACTED" + } + if u.Fragment != "" || strings.Contains(raw, "#") { + u.Fragment = "REDACTED" } return u.String() } +func redactMalformedURL(raw string) string { + redacted := raw + authorityStart := malformedAuthorityStart(redacted) + if authorityStart < 0 { + return redactQueryFragment(tokenLike.ReplaceAllString(redacted, `$1=REDACTED`)) + } + userinfoEnd := len(redacted) + if relEnd := strings.IndexAny(redacted[authorityStart:], "?#"); relEnd >= 0 { + userinfoEnd = authorityStart + relEnd + } + if at := strings.LastIndex(redacted[authorityStart:userinfoEnd], "@"); at >= 0 { + redacted = redacted[:authorityStart] + "REDACTED" + redacted[authorityStart+at:] + } else if userinfoEnd < len(redacted) && strings.Contains(redacted[userinfoEnd:], "@") && userinfoLikeBeforePath(redacted[authorityStart:userinfoEnd]) { + redacted = redacted[:authorityStart] + "REDACTED" + redacted[userinfoEnd:] + } + redacted = tokenLike.ReplaceAllString(redacted, `$1=REDACTED`) + return redactQueryFragment(redacted) +} + +func redactQueryFragment(raw string) string { + q := strings.Index(raw, "?") + f := strings.Index(raw, 
"#") + if q >= 0 && (f < 0 || q < f) { + if f >= 0 { + return raw[:q+1] + "REDACTED#REDACTED" + } + return raw[:q+1] + "REDACTED" + } + if f >= 0 { + return raw[:f+1] + "REDACTED" + } + return raw +} + +func HasInlineCredentials(raw string) bool { + if strings.ContainsAny(raw, "?#") { + return true + } + urlLike := strings.Contains(raw, "://") + u, err := url.Parse(raw) + if err != nil { + if malformedAuthorityStart(raw) >= 0 && strings.Contains(raw, "@") { + return true + } + if urlLike && strings.ContainsAny(raw, "@?#") { + return true + } + return tokenLike.MatchString(raw) + } + if u.User == nil && strings.Contains(raw, "@") && (isMalformedHTTPUserinfo(raw, u) || schemeLessUserinfoStart(raw) >= 0) { + return true + } + if u.RawQuery != "" || u.ForceQuery || u.Fragment != "" || strings.Contains(raw, "#") { + return true + } + if tokenLike.MatchString(u.RawQuery) { + return true + } + if u.User == nil { + return false + } + username := u.User.Username() + _, hasPassword := u.User.Password() + switch strings.ToLower(u.Scheme) { + case "http", "https": + return username != "" || hasPassword + case "ssh": + return hasPassword || tokenLikeUsername(username) + default: + return username != "" || hasPassword + } +} + +func malformedAuthorityStart(raw string) int { + lower := strings.ToLower(strings.TrimSpace(raw)) + if i := strings.Index(lower, "://"); i >= 0 { + return i + len("://") + } + if schemeLessUserinfoStart(raw) >= 0 { + return 0 + } + for _, prefix := range []string{"https://", "http://", "ssh://", "git://", "https:/", "http:/", "ssh:/", "git:/", "https//", "http//", "ssh//", "git//", "https:", "http:", "ssh:", "git:"} { + if strings.HasPrefix(lower, prefix) { + return len(prefix) + } + } + return -1 +} + +func isHTTPRemoteLike(raw string, scheme string) bool { + switch strings.ToLower(scheme) { + case "http", "https": + return true + case "": + return malformedAuthorityStart(raw) >= 0 + default: + return false + } +} + +func isMalformedHTTPUserinfo(raw 
string, u *url.URL) bool { + if !isHTTPRemoteLike(raw, u.Scheme) { + return false + } + if u.Host == "" { + return true + } + if strings.HasPrefix(u.Path, "/@") { + return true + } + if rawUsernameBeforeDelimiter(raw) { + return true + } + if strings.Contains(u.Path, "@") { + if slashAfterPathAt(u.Path) { + host := u.Hostname() + return !(isRealHost(host) && !isUserLikeHost(host) && pathBeforeAtSegmentCount(u.Path) > 1) + } + return !validParsedHostPort(u) + } + if !rawUserinfoCandidateHasPassword(raw) { + return false + } + pathBeforeAt := u.Path + if at := strings.Index(pathBeforeAt, "@"); at >= 0 { + pathBeforeAt = pathBeforeAt[:at] + } + if strings.Count(strings.Trim(pathBeforeAt, "/"), "/") > 0 { + host := u.Hostname() + if isRealHost(host) && !isUserLikeHost(host) && !slashAfterPathAt(u.Path) { + return false + } + } + return true +} + +func isRealHost(host string) bool { + return strings.Contains(host, ".") || host == "localhost" || net.ParseIP(host) != nil +} + +func validParsedHostPort(u *url.URL) bool { + if u.Host == "" { + return false + } + if strings.Contains(u.Path, "@") && slashAfterPathAt(u.Path) { + return false + } + host := u.Hostname() + if strings.HasPrefix(u.Host, "[") { + return true + } + if isUserLikeHost(host) { + return false + } + if u.Port() != "" { + return isRealHost(host) || isKnownSingleLabelGitHost(host) + } + if strings.Contains(u.Path, "@") && !isRealHost(host) && !isKnownSingleLabelGitHost(host) { + return false + } + return !strings.Contains(u.Host, ":") || u.Port() != "" +} + +func isUserLikeHost(host string) bool { + return host == "user" || host == "user.name" || host == "first.last" +} + +func tokenLikeUsername(username string) bool { + lower := strings.ToLower(username) + return strings.HasPrefix(lower, "ghp_") || + strings.HasPrefix(lower, "github_pat_") || + strings.HasPrefix(lower, "glpat-") || + strings.HasPrefix(lower, "gho_") || + strings.HasPrefix(lower, "ghu_") || + strings.HasPrefix(lower, "ghs_") || + 
strings.HasPrefix(lower, "ghr_") || + lower == "x-token-auth" || + lower == "oauth2" +} + +func isKnownSingleLabelGitHost(host string) bool { + return host == "git" || host == "ghe" +} + +func slashAfterPathAt(path string) bool { + at := strings.Index(path, "@") + return at >= 0 && strings.Contains(path[at:], "/") +} + +func pathBeforeAtSegmentCount(path string) int { + if at := strings.Index(path, "@"); at >= 0 { + path = path[:at] + } + count := 0 + for _, segment := range strings.Split(strings.Trim(path, "/"), "/") { + if segment != "" { + count++ + } + } + return count +} + +func rawUsernameBeforeDelimiter(raw string) bool { + authorityStart := malformedAuthorityStart(raw) + if authorityStart < 0 { + return false + } + at := strings.LastIndex(raw[authorityStart:], "@") + if at < 0 { + return false + } + at += authorityStart + if relEnd := strings.IndexAny(raw[authorityStart:at], "?#"); relEnd >= 0 { + candidate := raw[authorityStart : authorityStart+relEnd] + return userinfoLikeBeforePath(candidate) + } + return false +} + +func userinfoLikeBeforePath(candidate string) bool { + if candidate == "" { + return false + } + if slash := strings.Index(candidate, "/"); slash >= 0 { + prefix := candidate[:slash] + return strings.Contains(prefix, ":") || !isRealHost(prefix) + } + return true +} + +func schemeLessUserinfoStart(raw string) int { + if strings.Contains(raw, "://") { + return -1 + } + if isSCPStyleRemote(raw) { + return -1 + } + end := len(raw) + if relEnd := strings.IndexAny(raw, "/?#"); relEnd >= 0 { + end = relEnd + } + if end == 0 { + return -1 + } + prefix := raw[:end] + at := strings.LastIndex(prefix, "@") + colon := strings.Index(prefix, ":") + if colon >= 0 && (at > colon || strings.Contains(raw[end:], "@")) { + return 0 + } + return -1 +} + +func isSCPStyleRemote(raw string) bool { + if strings.Contains(raw, "://") { + return false + } + end := len(raw) + if relEnd := strings.IndexAny(raw, "/?#"); relEnd >= 0 { + end = relEnd + } + prefix := 
raw[:end] + at := strings.Index(prefix, "@") + colon := strings.Index(prefix, ":") + return at > 0 && colon > at +} + +func rawUserinfoCandidateHasPassword(raw string) bool { + authorityStart := malformedAuthorityStart(raw) + if authorityStart < 0 { + return false + } + userinfoEnd := len(raw) + if relEnd := strings.IndexAny(raw[authorityStart:], "?#"); relEnd >= 0 { + userinfoEnd = authorityStart + relEnd + } + at := strings.LastIndex(raw[authorityStart:userinfoEnd], "@") + if at < 0 { + return false + } + candidate := raw[authorityStart : authorityStart+at] + return strings.Contains(candidate, ":") +} + func RedactString(s string) string { if s == "" { return "" } - s = tokenLike.ReplaceAllString(s, `$1=REDACTED`) // Redact any URL-shaped substring with credentials (not just those with @) - if strings.Contains(s, "://") { + if strings.Contains(s, "://") || strings.ContainsAny(s, "?#@") { parts := strings.Split(s, " ") for i := range parts { - if strings.Contains(parts[i], "://") { - parts[i] = RedactRemoteURL(parts[i]) + if shouldRedactRemoteToken(parts[i]) { + parts[i] = redactRemoteToken(parts[i]) } } s = strings.Join(parts, " ") } + s = tokenLike.ReplaceAllString(s, `$1=REDACTED`) return s } + +func shouldRedactRemoteToken(token string) bool { + return strings.Contains(token, "://") || containsRemoteMarker(token) || schemeLessUserinfoStart(token) >= 0 || isSCPStyleRemote(token) +} + +func redactRemoteToken(token string) string { + start := strings.IndexFunc(token, func(r rune) bool { + return !strings.ContainsRune("'\"([{<", r) + }) + if start < 0 { + return token + } + end := strings.LastIndexFunc(token, func(r rune) bool { + return !strings.ContainsRune("'\")]}>,.:", r) + }) + if end < start { + return token + } + core := token[start : end+1] + return token[:start] + redactRemoteCore(core) + token[end+1:] +} + +func redactRemoteCore(core string) string { + var b strings.Builder + for len(core) > 0 { + sep := strings.IndexAny(core, ",;") + if sep < 0 { + 
b.WriteString(redactSingleRemoteCore(core)) + break + } + next := strings.TrimLeft(core[sep+1:], " ") + if separatorInsideUserinfo(core, sep) || !hasSplitRemoteMarker(next) { + b.WriteString(redactSingleRemoteCore(core)) + break + } + b.WriteString(redactSingleRemoteCore(core[:sep])) + b.WriteByte(core[sep]) + core = core[sep+1:] + } + return b.String() +} + +func separatorInsideUserinfo(core string, sep int) bool { + remoteStart := remoteStartIndex(core[:sep]) + if strings.Contains(core[remoteStart:sep], "@") { + return false + } + authorityStart := authorityStartInCore(core, remoteStart) + if authorityStart < 0 { + return false + } + candidate := core[authorityStart:sep] + if !userinfoLikeBeforePath(candidate) { + return false + } + if completeURLBeforeSeparator(core[remoteStart:sep]) { + return false + } + nextBoundary := len(core) + if relEnd := strings.IndexAny(core[sep+1:], " "); relEnd >= 0 { + nextBoundary = sep + 1 + relEnd + } + if at := strings.Index(core[sep+1:nextBoundary], "@"); at >= 0 && strings.Contains(core[sep+1:sep+1+at], ":") { + return true + } + return strings.Contains(core[sep+1:nextBoundary], "@") +} + +func completeURLBeforeSeparator(raw string) bool { + u, err := url.Parse(raw) + if err != nil || u.Scheme == "" || u.Host == "" || u.User != nil { + return false + } + if strings.HasPrefix(u.Host, "[") { + return true + } + host := u.Hostname() + return isRealHost(host) || isKnownSingleLabelGitHost(host) +} + +func authorityStartInCore(core string, remoteStart int) int { + rest := strings.ToLower(core[remoteStart:]) + for _, marker := range []string{"://", ":/", "//", ":"} { + if i := strings.Index(rest, marker); i >= 0 { + return remoteStart + i + len(marker) + } + } + return -1 +} + +func hasSeparateRemoteMarker(s string) bool { + lower := strings.ToLower(s) + for _, marker := range []string{"https://", "http://", "ssh://", "git://"} { + if strings.HasPrefix(lower, marker) { + return true + } + } + return false +} + +func 
// remoteStartIndex returns the byte offset in s at which a remote appears to
// start.
//
// The offset of the earliest scheme-like marker ("https://", "ssh:", the
// typo forms "https:/" and "https//", ...) wins, matched without regard to
// ASCII case. When no marker is present but s contains an '@' that is
// followed by '?' or '#', the offset just past the last '=' before that '@'
// is returned, so a "key=" prefix is kept out of the remote. Otherwise the
// result is 0.
//
// Callers slice s with the result, so matching is done on an ASCII-only,
// length-preserving fold of s. strings.ToLower can change the byte length of
// non-ASCII text, which yields offsets that do not index s.
func remoteStartIndex(s string) int {
	folded := []byte(s)
	for i, c := range folded {
		if 'A' <= c && c <= 'Z' {
			folded[i] = c + ('a' - 'A')
		}
	}
	lower := string(folded)

	best := -1
	for _, marker := range []string{"https://", "http://", "ssh://", "git://", "https:/", "http:/", "ssh:/", "git:/", "https//", "http//", "ssh//", "git//", "https:", "http:", "ssh:", "git:"} {
		if i := strings.Index(lower, marker); i >= 0 && (best < 0 || i < best) {
			best = i
		}
	}
	if best >= 0 {
		return best
	}
	if at := strings.Index(s, "@"); at >= 0 && strings.ContainsAny(s[at:], "?#") {
		if eq := strings.LastIndex(s[:at], "="); eq >= 0 {
			return eq + 1
		}
	}
	return 0
}
t.Fatalf("query secret leaked in output: %s", out) + } + + out = RedactRemoteURL("https://github.com/org/repo.git#access_token=ghp_secret") + if containsAny(out, []string{"access_token", "ghp_secret"}) { + t.Fatalf("fragment secret leaked in output: %s", out) + } + + out = RedactRemoteURL("https://ghp_secret%zz@github.com/org/repo.git") + if containsAny(out, []string{"ghp_secret", "%zz"}) { + t.Fatalf("malformed userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://user:pa#ss@example.com/org/repo.git") + if containsAny(out, []string{"user", "pa", "ss"}) { + t.Fatalf("malformed userinfo with fragment delimiter leaked in output: %s", out) + } + + out = RedactRemoteURL("https://ghp_secret#@github.com/org/repo.git") + if containsAny(out, []string{"ghp_secret"}) { + t.Fatalf("malformed username-only userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://user:pa/ss@example.com/org/repo.git") + if containsAny(out, []string{"user", "pa", "ss"}) { + t.Fatalf("malformed userinfo with path delimiter leaked in output: %s", out) + } + + out = RedactRemoteURL("https://ghp_secret%zz:password=abc@example.com/org/repo.git") + if containsAny(out, []string{"ghp_secret", "%zz", "abc"}) { + t.Fatalf("malformed userinfo with token-like password leaked in output: %s", out) + } + + out = RedactRemoteURL("git@github.com:org/repo.git?sig=s3cr3t") + if containsAny(out, []string{"sig", "s3cr3t"}) { + t.Fatalf("scp-style query secret leaked in output: %s", out) + } + + out = RedactRemoteURL("https://ghp_secret/@github.com/org/repo.git") + if containsAny(out, []string{"ghp_secret"}) { + t.Fatalf("malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https:/user:ghp_secret@github.com/org/repo.git") + if containsAny(out, []string{"ghp_secret", "user"}) { + t.Fatalf("missing-slash https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://user:123/ss@example.com/org/repo.git") + if containsAny(out, 
[]string{"user", "123", "ss"}) { + t.Fatalf("numeric-prefix https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://user:123/dir/ss@example.com/org/repo.git") + if containsAny(out, []string{"user", "123", "ss"}) { + t.Fatalf("multi-segment https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://user.name:pa/ss@example.com/org/repo.git") + if containsAny(out, []string{"user.name", "pa", "ss"}) { + t.Fatalf("dotted malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://user.name:123/dir/ss@example.com/org/repo.git") + if containsAny(out, []string{"user.name", "123", "ss"}) { + t.Fatalf("dotted numeric malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://first.last:123/dir/ss@example.com/org/repo.git") + if containsAny(out, []string{"first.last", "123", "ss"}) { + t.Fatalf("dotted name numeric malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://john.doe:123/ghp_secret@example.com/org/repo.git") + if containsAny(out, []string{"john.doe", "123", "ghp_secret"}) { + t.Fatalf("dotted name token malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://ghp_secret/part@example.com/org/repo.git") + if containsAny(out, []string{"ghp_secret", "part"}) { + t.Fatalf("path-split username-only malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://ghp_secret/part?x@example.com/org/repo.git") + if containsAny(out, []string{"ghp_secret", "part"}) { + t.Fatalf("query-split username-only malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://ghp_secret/part@example.com") + if containsAny(out, []string{"ghp_secret", "part"}) { + t.Fatalf("single-label path username-only malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://ghp_secret:8443/part@example.com") + if containsAny(out, 
[]string{"ghp_secret", "8443", "part"}) { + t.Fatalf("ported single-label path username-only malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://git.example.com/team/repo@2026/archive.git") + if out != "https://git.example.com/team/repo@2026/archive.git" { + t.Fatalf("real-host path with at sign and continuation was redacted: %s", out) + } + + out = RedactRemoteURL("file:///tmp/repo@2026.git") + if out != "file:///tmp/repo@2026.git" { + t.Fatalf("file URL path with at sign was redacted: %s", out) + } + + out = RedactRemoteURL("https://ghp.secret/part@example.com/org/repo.git") + if containsAny(out, []string{"ghp.secret", "part"}) { + t.Fatalf("dotted path-split username-only malformed https userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https//ghp_secret@github.com/org/repo.git") + if containsAny(out, []string{"ghp_secret"}) { + t.Fatalf("missing-colon https userinfo leaked in output: %s", out) + } + + out = RedactString("fatal: https://ghp_secret%zz:password=abc@example.com/org/repo.git failed") + if containsAny(out, []string{"ghp_secret", "%zz", "abc"}) { + t.Fatalf("redacted string leaked malformed credentials: %s", out) + } + + out = RedactString("fatal: unable to access 'https://user:pass@example.com/org/repo.git/': failed") + if containsAny(out, []string{"user", "pass"}) { + t.Fatalf("quoted URL leaked credentials: %s", out) + } + + out = RedactString("fatal: git@github.com:org/repo.git?sig=s3cr3t failed") + if containsAny(out, []string{"sig", "s3cr3t"}) { + t.Fatalf("scp-style query leaked in redacted string: %s", out) + } + + out = RedactRemoteURL("ssh://git:sec%zzret@example.com/org/repo.git") + if containsAny(out, []string{"sec", "%zz", "ret"}) { + t.Fatalf("malformed ssh userinfo leaked in output: %s", out) + } + + out = RedactString("fatal: ssh:/git:pa/ss@github.com/org/repo.git failed") + if containsAny(out, []string{"pa/ss"}) { + t.Fatalf("malformed ssh-style userinfo leaked in redacted 
string: %s", out) + } + + out = RedactString("fatal: remote=ssh:/git:pa/ss@github.com/org/repo.git failed") + if containsAny(out, []string{"pa/ss"}) { + t.Fatalf("prefixed malformed ssh-style userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https:/user:ghp_secret@github.com/org/repo.git failed") + if containsAny(out, []string{"ghp_secret", "user"}) { + t.Fatalf("http-like typo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: remote=https:/user:ghp_secret@github.com/org/repo.git failed") + if containsAny(out, []string{"ghp_secret", "user"}) { + t.Fatalf("prefixed http-like typo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: remote=https//ghp_secret@github.com/org/repo.git failed") + if containsAny(out, []string{"ghp_secret"}) { + t.Fatalf("prefixed missing-colon typo leaked in redacted string: %s", out) + } + + out = RedactString("remotes=ssh://git@github.com/org/repo.git,https://ghp_secret@example.com/private.git") + if containsAny(out, []string{"ghp_secret"}) { + t.Fatalf("second URL credential leaked in redacted string: %s", out) + } + + out = RedactString("remotes=https://git.example.com:8443/team/repo.git,https://ghp_secret@example.com/private.git") + if containsAny(out, []string{"ghp_secret"}) { + t.Fatalf("second URL credential after ported URL leaked in redacted string: %s", out) + } + + out = RedactString("remotes=https://git:8443/team/repo.git,https://ghp_secret@example.com/private.git") + if containsAny(out, []string{"ghp_secret"}) { + t.Fatalf("second URL credential after single-label ported URL leaked in redacted string: %s", out) + } + + out = RedactString("remotes=https://git.example.com/repo.git,alice:ghp_secret@github.com:org/repo.git") + if containsAny(out, []string{"alice", "ghp_secret"}) { + t.Fatalf("scheme-less second URL credential leaked in redacted string: %s", out) + } + + out = 
RedactString("remotes=https://git.example.com/repo.git;alice:ghp_secret@github.com:org/repo.git") + if containsAny(out, []string{"alice", "ghp_secret"}) { + t.Fatalf("scheme-less second URL credential after semicolon leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://github.com/org/repo.git?sig=abc,def failed") + if containsAny(out, []string{"abc", "def"}) { + t.Fatalf("query value leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://user:pa,ss@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "pa", "ss"}) { + t.Fatalf("comma-containing userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://user:pa;http:ss@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "pa", "ss"}) { + t.Fatalf("semicolon-containing userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://user:pa,http:ss?x@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "pa", "ss"}) { + t.Fatalf("comma/query malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://user:pa,https://ss@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "pa", "ss"}) { + t.Fatalf("comma/url-marker malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://user:pa/ss,https://x@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "pa/ss"}) { + t.Fatalf("slash comma/url-marker malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://user:pa/ss;https://x@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "pa/ss"}) { + t.Fatalf("slash semicolon/url-marker malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://use,r:pass@example.com/org/repo.git failed") + if containsAny(out, []string{"use", "pass"}) { + 
t.Fatalf("comma-split username leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://use;r:pass@example.com/org/repo.git failed") + if containsAny(out, []string{"use", "pass"}) { + t.Fatalf("semicolon-split username leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://user:pa;http:ss#x@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "pa", "ss"}) { + t.Fatalf("semicolon/fragment malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactRemoteURL("https://user:pa/ss?x@example.com/org/repo.git") + if containsAny(out, []string{"user", "pa/ss"}) { + t.Fatalf("slash query malformed userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("https://user:pa/ss#x@example.com/org/repo.git") + if containsAny(out, []string{"user", "pa/ss"}) { + t.Fatalf("slash fragment malformed userinfo leaked in output: %s", out) + } + + out = RedactString("fatal: https://user:8443/ss,https://x@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "8443/ss"}) { + t.Fatalf("numeric slash comma malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https://user:8443/ss;https://x@example.com/org/repo.git failed") + if containsAny(out, []string{"user", "8443/ss"}) { + t.Fatalf("numeric slash semicolon malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: what? failed issue#123") + if out != "fatal: what? failed issue#123" { + t.Fatalf("ordinary punctuation was redacted: %s", out) + } + + out = RedactString("contact dev@example.com? issue#123") + if out != "contact dev@example.com? 
issue#123" { + t.Fatalf("ordinary email punctuation was redacted: %s", out) + } + + out = RedactString("fatal: https//ghp_secret,http:ss?x@example.com/org/repo.git failed") + if containsAny(out, []string{"ghp_secret", "ss"}) { + t.Fatalf("missing-colon comma/query malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactString("fatal: https//ghp_secret;http:ss#x@example.com/org/repo.git failed") + if containsAny(out, []string{"ghp_secret", "ss"}) { + t.Fatalf("missing-colon semicolon/fragment malformed userinfo leaked in redacted string: %s", out) + } + + out = RedactRemoteURL("https://github.com/%zz?access_token=abc@def") + if containsAny(out, []string{"abc", "def"}) { + t.Fatalf("malformed query with at sign leaked in output: %s", out) + } + + out = RedactRemoteURL("alice:ghp_secret@github.com:org/repo.git") + if containsAny(out, []string{"alice", "ghp_secret"}) { + t.Fatalf("scheme-less malformed userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("x-token-auth:secret@bitbucket.org/org/repo.git") + if containsAny(out, []string{"x-token-auth", "secret"}) { + t.Fatalf("scheme-less token userinfo leaked in output: %s", out) + } + + out = RedactRemoteURL("alice:ghp_secret#@github.com:org/repo.git") + if containsAny(out, []string{"alice", "ghp_secret"}) { + t.Fatalf("scheme-less delimiter-split userinfo leaked in output: %s", out) + } +} + +func TestHasInlineCredentials(t *testing.T) { + tests := []struct { + name string + raw string + want bool + }{ + { + name: "https userinfo", + raw: "https://token@example.com/org/repo.git", + want: true, + }, + { + name: "token query parameter", + raw: "https://github.com/org/repo.git?access_token=secret", + want: true, + }, + { + name: "token-like query parameter", + raw: "https://github.com/org/repo.git?x-token-auth=secret", + want: true, + }, + { + name: "unrecognized query secret", + raw: "https://github.com/org/repo.git?sig=s3cr3t", + want: true, + }, + { + name: "benign-looking query", + 
raw: "https://github.com/org/repo.git?ref=main", + want: true, + }, + { + name: "fragment secret", + raw: "https://github.com/org/repo.git#access_token=secret", + want: true, + }, + { + name: "empty fragment marker", + raw: "https://github.com/org/repo.git#", + want: true, + }, + { + name: "malformed https userinfo path split", + raw: "https://ghp_secret/@github.com/org/repo.git", + want: true, + }, + { + name: "malformed https username path continuation", + raw: "https://ghp_secret/part@example.com/org/repo.git", + want: true, + }, + { + name: "malformed https username query continuation", + raw: "https://ghp_secret/part?x@example.com/org/repo.git", + want: true, + }, + { + name: "malformed https username single-label path", + raw: "https://ghp_secret/part@example.com", + want: true, + }, + { + name: "malformed https username ported single-label path", + raw: "https://ghp_secret:8443/part@example.com", + want: true, + }, + { + name: "malformed https userinfo missing slash", + raw: "https:/user:ghp_secret@github.com/org/repo.git", + want: true, + }, + { + name: "malformed https userinfo numeric prefix", + raw: "https://user:123/ss@example.com/org/repo.git", + want: true, + }, + { + name: "malformed https userinfo missing colon", + raw: "https//ghp_secret@github.com/org/repo.git", + want: true, + }, + { + name: "malformed https parse error missing colon", + raw: "https//ghp_secret%zz@github.com/org/repo.git", + want: true, + }, + { + name: "empty query marker", + raw: "https://github.com/org/repo.git?", + want: true, + }, + { + name: "malformed userinfo", + raw: "https://user%zz@github.com/org/repo.git", + want: true, + }, + { + name: "malformed query secret", + raw: "https://github.com/org/%zz.git?sig=s3cr3t", + want: true, + }, + { + name: "scp-style ssh", + raw: "git@github.com:org/repo.git", + want: false, + }, + { + name: "scheme-less malformed userinfo", + raw: "alice:ghp_secret@github.com:org/repo.git", + want: true, + }, + { + name: "scheme-less bitbucket 
token userinfo", + raw: "x-token-auth:secret@bitbucket.org/org/repo.git", + want: true, + }, + { + name: "scp-style ssh with query", + raw: "git@github.com:org/repo.git?sig=s3cr3t", + want: true, + }, + { + name: "scp-style ssh with fragment", + raw: "git@github.com:org/repo.git#access_token=secret", + want: true, + }, + { + name: "scp-style root path with at sign", + raw: "git@example.com:repo:v1@2026.git", + want: false, + }, + { + name: "ssh username only", + raw: "ssh://git@github.com/org/repo.git", + want: false, + }, + { + name: "ssh token username", + raw: "ssh://ghp_abcdefghijklmnopqrstuvwxyz@github.com/org/repo.git", + want: true, + }, + { + name: "ssh username containing token", + raw: "ssh://token-admin@git.example.com/org/repo.git", + want: false, + }, + { + name: "git protocol username only", + raw: "git://ghp_secret@github.com/org/repo.git", + want: true, + }, + { + name: "ssh username with query", + raw: "ssh://git@github.com/org/repo.git?ref=main", + want: true, + }, + { + name: "ssh password", + raw: "ssh://git:secret@github.com/org/repo.git", + want: true, + }, + { + name: "plain https remote", + raw: "https://github.com/org/repo.git", + want: false, + }, + { + name: "file URL path with at sign", + raw: "file:///tmp/repo@2026.git", + want: false, + }, + { + name: "https path with at sign", + raw: "https://git.example.com/team/repo@2026.git", + want: false, + }, + { + name: "https ported path with at sign", + raw: "https://git.example.com:8443/team/repo@2026.git", + want: false, + }, + { + name: "https path with colon and at sign", + raw: "https://git.example.com/team/repo:v1@2026.git", + want: false, + }, + { + name: "user subdomain path with colon and at sign", + raw: "https://user.example.com/team/repo:v1@2026.git", + want: false, + }, + { + name: "root https path with colon and at sign", + raw: "https://git.example.com/repo:v1@2026.git", + want: false, + }, + { + name: "single label https path with colon and at sign", + raw: 
"https://git/repo:v1@2026.git", + want: false, + }, + { + name: "single label ported https path with at sign", + raw: "https://git:8443/team/repo@2026.git", + want: false, + }, + { + name: "single label ported root https path with at sign", + raw: "https://git:8443/repo@2026.git", + want: false, + }, + { + name: "ghe ported https path with at sign", + raw: "https://ghe:8443/team/repo@2026.git", + want: false, + }, + { + name: "real host path with at sign continuation", + raw: "https://git.example.com/team/repo@2026/archive.git", + want: false, + }, + { + name: "real host ported path with at sign continuation", + raw: "https://git.example.com:8443/team/repo@2026/archive.git", + want: false, + }, + { + name: "localhost path with colon and at sign", + raw: "https://localhost/team/repo:v1@2026.git", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HasInlineCredentials(tt.raw); got != tt.want { + t.Fatalf("HasInlineCredentials(%q) = %v, want %v", tt.raw, got, tt.want) + } + }) + } } func containsAny(s string, needles []string) bool { diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 11a87fa..4ff7860 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/cloudflare/artifact-fs/internal/auth" "github.com/cloudflare/artifact-fs/internal/daemon" "github.com/cloudflare/artifact-fs/internal/logging" "github.com/cloudflare/artifact-fs/internal/model" @@ -63,12 +64,27 @@ func Run(ctx context.Context, args []string, stdout io.Writer, stderr io.Writer) ucli.StringFlag{Name: "branch", Value: "main", Usage: "branch to track"}, ucli.StringFlag{Name: "refresh", Value: "30s", Usage: "refresh interval"}, ucli.StringFlag{Name: "mount-root", Usage: "override mount root"}, + ucli.BoolFlag{Name: "async", Usage: "return after registration and prepare the repo in the daemon"}, + ucli.BoolFlag{Name: "prepared-gitdir", Usage: "use an existing git dir for async 
preparation"}, + ucli.StringFlag{Name: "git-dir", Usage: "explicit git dir path"}, + ucli.StringFlag{Name: "fetch-ref", Usage: "ref to fetch during async preparation"}, }, Action: withService(ctx, root, stderr, func(c *ucli.Context, svc *daemon.Service) error { name := strings.TrimSpace(c.String("name")) remote := strings.TrimSpace(c.String("remote")) - if name == "" || remote == "" { - return fmt.Errorf("--name and --remote are required") + async := c.Bool("async") + preparedGitDir := c.Bool("prepared-gitdir") + if preparedGitDir && !async { + return fmt.Errorf("--prepared-gitdir requires --async") + } + if preparedGitDir && strings.TrimSpace(c.String("git-dir")) == "" { + return fmt.Errorf("--git-dir is required with --prepared-gitdir") + } + if name == "" { + return fmt.Errorf("--name is required") + } + if remote == "" && !preparedGitDir { + return fmt.Errorf("--remote is required") } d, err := daemon.ParseRefresh(c.String("refresh")) if err != nil { @@ -81,11 +97,18 @@ func Run(ctx context.Context, args []string, stdout io.Writer, stderr io.Writer) Branch: c.String("branch"), RefreshInterval: d, MountRoot: c.String("mount-root"), + GitDir: c.String("git-dir"), + PreparedGitDir: preparedGitDir, + FetchRef: c.String("fetch-ref"), Enabled: true, } - if err := svc.AddRepo(ctx, cfg); err != nil { + if err := svc.AddRepoWithOptions(ctx, cfg, daemon.AddRepoOptions{Async: async}); err != nil { return err } + if async { + fmt.Fprintf(stdout, "queued %s\n", cfg.Name) + return nil + } fmt.Fprintf(stdout, "added %s\n", cfg.Name) return nil }), @@ -112,6 +135,13 @@ func Run(ctx context.Context, args []string, stdout io.Writer, stderr io.Writer) fmt.Fprintf(w, "fetched %s\n", name) return nil }), + nameCommand("prepare", "retry async repository preparation", ctx, root, stderr, stdout, func(c context.Context, svc *daemon.Service, name string, w io.Writer) error { + if err := svc.Prepare(c, name); err != nil { + return err + } + fmt.Fprintf(w, "preparing %s\n", name) + return 
nil + }), { Name: "list-repos", Usage: "list configured repos", @@ -216,10 +246,11 @@ func nameCommand(name, usage string, ctx context.Context, root string, stderr io } func formatStatusLine(st model.RepoRuntimeState) string { - return fmt.Sprintf("repo=%s state=%s head=%s ref=%s ahead=%d behind=%d diverged=%t last_fetch=%s result=%s hydrated_blobs=%d hydrated_bytes=%d overlay_dirty=%t", + return fmt.Sprintf("repo=%s state=%s head=%s ref=%s ahead=%d behind=%d diverged=%t last_fetch=%s result=%s prepare_error=%s hydrated_blobs=%d hydrated_bytes=%d overlay_dirty=%t", st.RepoID, st.State, st.CurrentHEADOID, st.CurrentHEADRef, st.AheadCount, st.BehindCount, st.Diverged, formatLastFetchAt(st.LastFetchAt), formatLastFetchResult(st.LastFetchResult), + formatPrepareError(st.PrepareError), st.HydratedBlobCount, st.HydratedBlobBytes, st.DirtyOverlay) } @@ -237,6 +268,13 @@ func formatLastFetchResult(result string) string { return result } +func formatPrepareError(err string) string { + if strings.TrimSpace(err) == "" { + return "none" + } + return strings.Join(strings.Fields(auth.RedactString(err)), "_") +} + func stubCommand(name, usage string, stdout io.Writer) ucli.Command { return ucli.Command{ Name: name, diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 05ad6bd..84e9d1a 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -1,6 +1,8 @@ package cli import ( + "bytes" + "context" "strings" "testing" "time" @@ -23,6 +25,7 @@ func TestFormatStatusLineUsesNeverForUnsetFetch(t *testing.T) { for _, want := range []string{ "last_fetch=never", "result=never", + "prepare_error=none", "hydrated_blobs=3", "hydrated_bytes=42", } { @@ -44,3 +47,113 @@ func TestFormatStatusLineFormatsFetchTimestamp(t *testing.T) { t.Fatalf("status line %q missing formatted timestamp", got) } } + +func TestFormatStatusLineKeepsPrepareErrorSingleLine(t *testing.T) { + st := model.RepoRuntimeState{PrepareError: "fatal: clone failed\ntry again\tlater"} + + got := 
formatStatusLine(st) + if strings.ContainsAny(got, "\n\t") { + t.Fatalf("status line contains raw whitespace: %q", got) + } + if !strings.Contains(got, "prepare_error=fatal:_clone_failed_try_again_later") { + t.Fatalf("status line %q missing normalized prepare error", got) + } +} + +func TestFormatStatusLineRedactsPrepareError(t *testing.T) { + st := model.RepoRuntimeState{PrepareError: "clone https://token@example.com/org/repo.git?access_token=secret failed"} + + got := formatStatusLine(st) + if strings.Contains(got, "token@example.com") || strings.Contains(got, "secret") { + t.Fatalf("status line leaked prepare error credential: %q", got) + } + if !strings.Contains(got, "REDACTED") { + t.Fatalf("status line %q missing redaction marker", got) + } +} + +func TestAddRepoAsyncCLIRegistersWithoutClone(t *testing.T) { + t.Setenv("ARTIFACT_FS_ROOT", t.TempDir()) + var stdout, stderr bytes.Buffer + + code := Run(context.Background(), []string{ + "add-repo", + "--name", "repo", + "--remote", "https://github.com/example/repo.git", + "--branch", "main", + "--async", + }, &stdout, &stderr) + if code != 0 { + t.Fatalf("Run exit = %d, stderr=%q", code, stderr.String()) + } + if got := stdout.String(); got != "queued repo\n" { + t.Fatalf("stdout = %q, want queued repo", got) + } +} + +func TestAddRepoAsyncCLIFlagValidation(t *testing.T) { + tests := []struct { + name string + args []string + want string + }{ + { + name: "prepared_gitdir_requires_async", + args: []string{"add-repo", "--name", "repo", "--prepared-gitdir", "--git-dir", "/tmp/repo.git"}, + want: "--prepared-gitdir requires --async", + }, + { + name: "prepared_gitdir_requires_git_dir", + args: []string{"add-repo", "--name", "repo", "--async", "--prepared-gitdir"}, + want: "--git-dir is required with --prepared-gitdir", + }, + { + name: "async_clone_requires_remote", + args: []string{"add-repo", "--name", "repo", "--async"}, + want: "--remote is required", + }, + { + name: "async_rejects_inline_credentials", + args: 
[]string{"add-repo", "--name", "repo", "--async", "--remote", "https://token@example.com/org/repo.git"}, + want: "async repositories must use ambient credentials", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("ARTIFACT_FS_ROOT", t.TempDir()) + var stdout, stderr bytes.Buffer + code := Run(context.Background(), tt.args, &stdout, &stderr) + if code == 0 { + t.Fatalf("Run unexpectedly succeeded, stdout=%q", stdout.String()) + } + if !strings.Contains(stderr.String(), tt.want) { + t.Fatalf("stderr = %q, want substring %q", stderr.String(), tt.want) + } + }) + } +} + +func TestPrepareCLIReportsQueuedPrepare(t *testing.T) { + t.Setenv("ARTIFACT_FS_ROOT", t.TempDir()) + var stdout, stderr bytes.Buffer + ctx := context.Background() + code := Run(ctx, []string{ + "add-repo", + "--name", "repo", + "--remote", "https://github.com/example/repo.git", + "--async", + }, &stdout, &stderr) + if code != 0 { + t.Fatalf("add-repo exit = %d, stderr=%q", code, stderr.String()) + } + + stdout.Reset() + stderr.Reset() + code = Run(ctx, []string{"prepare", "--name", "repo"}, &stdout, &stderr) + if code != 0 { + t.Fatalf("prepare exit = %d, stderr=%q", code, stderr.String()) + } + if got := stdout.String(); got != "preparing repo\n" { + t.Fatalf("stdout = %q, want preparing repo", got) + } +} diff --git a/internal/daemon/async_test.go b/internal/daemon/async_test.go new file mode 100644 index 0000000..a9add38 --- /dev/null +++ b/internal/daemon/async_test.go @@ -0,0 +1,760 @@ +package daemon + +import ( + "context" + "errors" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/cloudflare/artifact-fs/internal/fusefs" + "github.com/cloudflare/artifact-fs/internal/model" + "github.com/cloudflare/artifact-fs/internal/snapshot" +) + +func TestAddRepoAsyncRegistersWithoutClone(t *testing.T) { + ctx := context.Background() + root := t.TempDir() + svc, err := New(ctx, root, 
slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + RemoteURL: "https://github.com/example/repo.git", + Branch: "master", + RefreshInterval: time.Minute, + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if got.PrepareState != model.PrepareStatePreparing { + t.Fatalf("PrepareState = %q, want preparing", got.PrepareState) + } + if got.FetchRef != "master" { + t.Fatalf("FetchRef = %q, want master", got.FetchRef) + } + if got.GitDir != filepath.Join(root, "repos", "repo", "git") { + t.Fatalf("GitDir = %q", got.GitDir) + } +} + +func TestAddRepoAsyncRejectsInlineCredentials(t *testing.T) { + ctx := context.Background() + svc, err := New(ctx, t.TempDir(), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer func() { + svc.mu.Lock() + delete(svc.preparing, model.RepoID("repo")) + svc.mu.Unlock() + _ = svc.Close() + }() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + RemoteURL: "https://token@example.com/org/repo.git", + Branch: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err == nil { + t.Fatal("expected inline credential error") + } +} + +func TestRunPrepareRejectsPersistedInlineCredentialsBeforeClone(t *testing.T) { + ctx := context.Background() + root := t.TempDir() + svc, err := New(ctx, root, slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + RemoteURL: "https://token@example.com/org/repo.git", + Branch: "master", + FetchRef: "master", + PrepareState: model.PrepareStatePreparing, + RefreshInterval: time.Minute, + Enabled: true, + } + svc.fillPaths(&cfg) + if err := 
svc.registry.AddRepo(ctx, cfg); err != nil { + t.Fatal(err) + } + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + + if err := svc.runPrepare(ctx, got); err == nil { + t.Fatal("expected runPrepare failure") + } + if _, err := os.Stat(got.GitDir); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("git dir stat = %v, want not exist", err) + } + got, err = svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if got.PrepareState != model.PrepareStateFailed { + t.Fatalf("PrepareState = %q, want failed", got.PrepareState) + } + if strings.Contains(got.PrepareError, "token") { + t.Fatalf("PrepareError was not redacted: %q", got.PrepareError) + } +} + +func TestAddRepoPreparedGitDirValidation(t *testing.T) { + ctx := context.Background() + svc, err := New(ctx, t.TempDir(), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + Branch: "master", + PreparedGitDir: true, + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{}); err == nil { + t.Fatal("expected --prepared-gitdir requires --async error") + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err == nil { + t.Fatal("expected --git-dir required error") + } +} + +func TestSyncReposSkipsResetWhilePrepareWorkerInFlight(t *testing.T) { + ctx := context.Background() + svc, err := New(ctx, t.TempDir(), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer func() { + svc.mu.Lock() + delete(svc.preparing, model.RepoID("repo")) + svc.mu.Unlock() + _ = svc.Close() + }() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + RemoteURL: "https://github.com/example/repo.git", + Branch: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + cfg, err = svc.registry.GetRepo(ctx, "repo") + if 
err != nil { + t.Fatal(err) + } + + gate := fusefs.NewReadyGate(true) + rt := &repoRuntime{ + cfg: cfg, + gate: gate, + active: true, + state: model.RepoRuntimeState{ + RepoID: cfg.ID, + State: repoStateMounted, + PrepareError: "", + }, + } + svc.mu.Lock() + svc.running[cfg.ID] = rt + svc.preparing[cfg.ID] = 1 + svc.mu.Unlock() + + if err := svc.syncRepos(ctx); err != nil { + t.Fatal(err) + } + if rt.state.State != repoStateMounted { + t.Fatalf("runtime state = %q, want mounted", rt.state.State) + } + if err := gate.Wait(ctx); err != nil { + t.Fatalf("gate was reset while prepare worker was in flight: %v", err) + } +} + +func TestRestartRunningPrepareSkipsStalePreparingSnapshot(t *testing.T) { + ctx := context.Background() + svc, err := New(ctx, t.TempDir(), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer func() { + svc.mu.Lock() + delete(svc.running, model.RepoID("repo")) + delete(svc.preparing, model.RepoID("repo")) + svc.mu.Unlock() + _ = svc.Close() + }() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + RemoteURL: "https://github.com/example/repo.git", + Branch: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + latest, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if err := svc.registry.UpdatePrepareState(ctx, latest.ID, model.PrepareStateReady, ""); err != nil { + t.Fatal(err) + } + stale := latest + stale.PrepareState = model.PrepareStatePreparing + + gate := fusefs.NewReadyGate(true) + rt := &repoRuntime{ + cfg: latest, + gate: gate, + active: true, + state: model.RepoRuntimeState{ + RepoID: latest.ID, + State: repoStateMounted, + }, + } + svc.mu.Lock() + svc.running[latest.ID] = rt + svc.mu.Unlock() + + svc.restartRunningPrepareIfCurrent(ctx, stale, rt, false) + if rt.state.State != repoStateMounted { + t.Fatalf("runtime state = %q, want mounted", rt.state.State) + } + if err := 
gate.Wait(ctx); err != nil { + t.Fatalf("gate was reset from stale registry snapshot: %v", err) + } + svc.mu.Lock() + _, preparing := svc.preparing[latest.ID] + svc.mu.Unlock() + if preparing { + t.Fatal("started prepare worker from stale registry snapshot") + } +} + +func TestRunPrepareFailurePersistsRedactedError(t *testing.T) { + ctx := context.Background() + svc, err := New(ctx, t.TempDir(), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + Branch: "master", + GitDir: filepath.Join(t.TempDir(), "missing.git"), + PreparedGitDir: true, + FetchRef: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if err := svc.runPrepare(ctx, got); err == nil { + t.Fatal("expected runPrepare failure") + } + got, err = svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if got.PrepareState != model.PrepareStateFailed { + t.Fatalf("PrepareState = %q, want failed", got.PrepareState) + } + if got.PrepareError == "" { + t.Fatal("PrepareError is empty, want persisted failure") + } + + if err := svc.setPrepareState(ctx, got, model.PrepareStateFailed, errors.New("clone https://token@example.com/org/repo.git failed")); err != nil { + t.Fatal(err) + } + got, err = svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if strings.Contains(got.PrepareError, "token") { + t.Fatalf("PrepareError was not redacted: %q", got.PrepareError) + } +} + +func TestStartPrepareWorkerTimesOutAndPersistsFailed(t *testing.T) { + ctx := context.Background() + bin := filepath.Join(t.TempDir(), "bin") + if err := os.MkdirAll(bin, 0o755); err != nil { + t.Fatal(err) + } + gitPath := filepath.Join(bin, "git") + if err := os.WriteFile(gitPath, []byte("#!/bin/sh\nexec sleep 10\n"), 0o755); 
err != nil { + t.Fatal(err) + } + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + svc, err := New(ctx, t.TempDir(), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + svc.prepareTimeout = 20 * time.Millisecond + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + RemoteURL: "https://github.com/example/repo.git", + Branch: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + svc.startPrepareWorker(ctx, got) + + got = waitForPrepareState(t, svc, "repo", model.PrepareStateFailed) + if got.PrepareError != "prepare timed out" { + t.Fatalf("PrepareError = %q, want timeout", got.PrepareError) + } + waitFor(t, time.Second, func() bool { + svc.mu.Lock() + defer svc.mu.Unlock() + _, preparing := svc.preparing[got.ID] + return !preparing + }) +} + +func TestRunPreparePreparedGitDirPublishesSnapshotAndMarksReady(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + bare := filepath.Join(tmp, "origin.git") + work := filepath.Join(tmp, "work") + preparedGitDir := filepath.Join(tmp, "prepared.git") + preparedWorktree := filepath.Join(tmp, "prepared") + + runCmd(t, "git", "init", "--bare", bare) + runCmd(t, "git", "clone", bare, work) + runCmd(t, "git", "-C", work, "checkout", "-b", "master") + if err := os.WriteFile(filepath.Join(work, "README.md"), []byte("hello\n"), 0o644); err != nil { + t.Fatal(err) + } + runCmd(t, "git", "-C", work, "add", "README.md") + runCmd(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "init") + runCmd(t, "git", "-C", work, "push", "origin", "master") + + runCmd(t, "git", "init", "--separate-git-dir", preparedGitDir, "--initial-branch", "master", preparedWorktree) + runCmd(t, "git", "-C", preparedWorktree, "remote", "add", "origin", 
"file://"+bare) + + root := filepath.Join(tmp, "artifact-fs") + svc, err := New(ctx, root, slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + Branch: "master", + RefreshInterval: time.Minute, + GitDir: preparedGitDir, + PreparedGitDir: true, + FetchRef: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if err := svc.runPrepare(ctx, got); err != nil { + t.Fatalf("runPrepare: %v", err) + } + + got, err = svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if got.PrepareState != model.PrepareStateReady { + t.Fatalf("PrepareState = %q, want ready", got.PrepareState) + } + if got.PrepareError != "" { + t.Fatalf("PrepareError = %q, want empty", got.PrepareError) + } + snap, err := snapshot.New(ctx, got.MetaDBPath) + if err != nil { + t.Fatal(err) + } + defer snap.Close() + _, ref, gen, err := snap.ReadState(ctx) + if err != nil { + t.Fatal(err) + } + if ref != "master" { + t.Fatalf("snapshot ref = %q, want master", ref) + } + if gen == 0 { + t.Fatal("snapshot generation = 0, want published generation") + } + if _, ok := snap.GetNode(gen, "README.md"); !ok { + t.Fatal("README.md not found in snapshot") + } +} + +func TestSizeUpdateBatcherFlushesOnStop(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + snap, err := snapshot.New(ctx, filepath.Join(tmp, "snap.sqlite")) + if err != nil { + t.Fatal(err) + } + defer snap.Close() + gen, err := snap.PublishGeneration(ctx, "head", "master", []model.BaseNode{ + {RepoID: "repo", Path: ".", Type: "dir", Mode: 0o755, SizeState: "known"}, + {RepoID: "repo", Path: "a.txt", Type: "file", Mode: 0o644, ObjectOID: "a", SizeState: "unknown"}, + {RepoID: "repo", Path: "b.txt", Type: "file", Mode: 0o644, ObjectOID: "b", 
SizeState: "unknown"}, + }) + if err != nil { + t.Fatal(err) + } + + runCtx, cancel := context.WithCancel(ctx) + batcher := newSizeUpdateBatcher(snap, slog.New(slog.NewTextHandler(io.Discard, nil)), "repo") + batcher.Start(runCtx) + batcher.Add(gen, "a", 10) + batcher.Add(gen, "b", 20) + cancel() + batcher.Stop() + + n, ok := snap.GetNode(gen, "a.txt") + if !ok { + t.Fatal("a.txt not found") + } + if n.SizeState != "known" || n.SizeBytes != 10 { + t.Fatalf("a.txt size = %s/%d, want known/10", n.SizeState, n.SizeBytes) + } + n, ok = snap.GetNode(gen, "b.txt") + if !ok { + t.Fatal("b.txt not found") + } + if n.SizeState != "known" || n.SizeBytes != 20 { + t.Fatalf("b.txt size = %s/%d, want known/20", n.SizeState, n.SizeBytes) + } +} + +func TestRunPrepareFreshCloneSkipsSecondFetchForBranchFetchRef(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + bare := filepath.Join(tmp, "origin.git") + work := filepath.Join(tmp, "work") + + runCmd(t, "git", "init", "--bare", bare) + runCmd(t, "git", "clone", bare, work) + runCmd(t, "git", "-C", work, "checkout", "-b", "master") + if err := os.WriteFile(filepath.Join(work, "README.md"), []byte("hello\n"), 0o644); err != nil { + t.Fatal(err) + } + runCmd(t, "git", "-C", work, "add", "README.md") + runCmd(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "init") + runCmd(t, "git", "-C", work, "push", "origin", "master") + + for _, fetchRef := range []string{"master", "refs/heads/master", "origin/master", "refs/remotes/origin/master"} { + t.Run(fetchRef, func(t *testing.T) { + svc, err := New(ctx, filepath.Join(tmp, "artifact-fs", strings.ReplaceAll(fetchRef, "/", "-")), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + RemoteURL: "file://" + bare, + Branch: "master", + FetchRef: fetchRef, + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, 
AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + + bin := filepath.Join(t.TempDir(), "bin") + if err := os.Mkdir(bin, 0o755); err != nil { + t.Fatal(err) + } + logPath := filepath.Join(t.TempDir(), "git.log") + fakeGit := filepath.Join(bin, "git") + if err := os.WriteFile(fakeGit, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$GIT_COMMAND_LOG\"\nexec /usr/bin/git \"$@\"\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("GIT_COMMAND_LOG", logPath) + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + if err := svc.runPrepare(ctx, got); err != nil { + t.Fatalf("runPrepare: %v", err) + } + logData, err := os.ReadFile(logPath) + if err != nil { + t.Fatal(err) + } + for _, line := range strings.Split(string(logData), "\n") { + if strings.HasPrefix(line, "fetch ") { + t.Fatalf("fresh branch clone ran redundant fetch; git log:\n%s", logData) + } + } + + got, err = svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if got.PrepareState != model.PrepareStateReady { + t.Fatalf("PrepareState = %q, want ready", got.PrepareState) + } + }) + } +} + +func TestRunPrepareDoesNotOpenGateWhenReadyPersistenceFails(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + preparedGitDir := createPreparedGitDir(t, tmp) + + svc, err := New(ctx, filepath.Join(tmp, "artifact-fs"), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + Branch: "master", + RefreshInterval: time.Minute, + GitDir: preparedGitDir, + PreparedGitDir: true, + FetchRef: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if err := svc.mountAsyncRepo(ctx, got); err != nil { + 
t.Fatal(err) + } + svc.mu.Lock() + gate := svc.running[got.ID].gate + svc.mu.Unlock() + if gate == nil { + t.Fatal("runtime gate is nil") + } + if err := svc.registry.Close(); err != nil { + t.Fatal(err) + } + + if err := svc.runPrepare(ctx, got); err == nil { + t.Fatal("expected ready state persistence failure") + } + waitCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) + defer cancel() + if err := gate.Wait(waitCtx); !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("gate wait = %v, want deadline because ready was not durably persisted", err) + } +} + +func TestRunPrepareDoesNotMarkSupersededConfigReady(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + firstGitDir := createPreparedGitDir(t, filepath.Join(tmp, "first")) + secondGitDir := createPreparedGitDir(t, filepath.Join(tmp, "second")) + + svc, err := New(ctx, filepath.Join(tmp, "artifact-fs"), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + Branch: "master", + RefreshInterval: time.Minute, + GitDir: firstGitDir, + PreparedGitDir: true, + FetchRef: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + first, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + cfg.GitDir = secondGitDir + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + + err = svc.runPrepare(ctx, first) + if err == nil || !strings.Contains(err.Error(), "superseded") { + t.Fatalf("runPrepare error = %v, want superseded", err) + } + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if got.PrepareState != model.PrepareStatePreparing { + t.Fatalf("PrepareState = %q, want preparing", got.PrepareState) + } + if got.GitDir != secondGitDir { + t.Fatalf("GitDir = %q, want newer git dir %q", 
got.GitDir, secondGitDir) + } +} + +func TestRunPrepareDoesNotMarkSupersededConfigFailed(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + svc, err := New(ctx, filepath.Join(tmp, "artifact-fs"), slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatal(err) + } + defer svc.Close() + + cfg := model.RepoConfig{ + Name: "repo", + ID: "repo", + Branch: "master", + RefreshInterval: time.Minute, + GitDir: filepath.Join(tmp, "missing.git"), + PreparedGitDir: true, + FetchRef: "master", + Enabled: true, + } + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + first, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + cfg.GitDir = createPreparedGitDir(t, filepath.Join(tmp, "second")) + if err := svc.AddRepoWithOptions(ctx, cfg, AddRepoOptions{Async: true}); err != nil { + t.Fatal(err) + } + + if err := svc.runPrepare(ctx, first); err == nil { + t.Fatal("expected stale prepare failure") + } + got, err := svc.registry.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if got.PrepareState != model.PrepareStatePreparing { + t.Fatalf("PrepareState = %q, want preparing", got.PrepareState) + } + if got.PrepareError != "" { + t.Fatalf("PrepareError = %q, want empty", got.PrepareError) + } +} + +func createPreparedGitDir(t *testing.T, tmp string) string { + t.Helper() + bare := filepath.Join(tmp, "origin.git") + work := filepath.Join(tmp, "work") + preparedGitDir := filepath.Join(tmp, "prepared.git") + preparedWorktree := filepath.Join(tmp, "prepared") + + runCmd(t, "git", "init", "--bare", bare) + runCmd(t, "git", "clone", bare, work) + runCmd(t, "git", "-C", work, "checkout", "-b", "master") + if err := os.WriteFile(filepath.Join(work, "README.md"), []byte("hello\n"), 0o644); err != nil { + t.Fatal(err) + } + runCmd(t, "git", "-C", work, "add", "README.md") + runCmd(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", 
"commit", "-m", "init") + runCmd(t, "git", "-C", work, "push", "origin", "master") + + runCmd(t, "git", "init", "--separate-git-dir", preparedGitDir, "--initial-branch", "master", preparedWorktree) + runCmd(t, "git", "-C", preparedWorktree, "remote", "add", "origin", "file://"+bare) + return preparedGitDir +} + +func waitForPrepareState(t *testing.T, svc *Service, name string, state string) model.RepoConfig { + t.Helper() + var got model.RepoConfig + waitFor(t, 2*time.Second, func() bool { + var err error + got, err = svc.registry.GetRepo(context.Background(), name) + return err == nil && got.PrepareState == state + }) + return got +} + +func waitFor(t *testing.T, timeout time.Duration, ready func() bool) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if ready() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("condition was not met before timeout") +} + +func runCmd(t *testing.T, name string, args ...string) { + t.Helper() + cmd := exec.Command(name, args...) 
+ out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("%s %v failed: %v\n%s", name, args, err, string(out)) + } +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index d650f6a..ca2267e 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -26,15 +26,30 @@ import ( const DefaultHydrationConcurrency = 4 +const ( + defaultPrepareTimeout = 30 * time.Minute + prepareStateWriteTimeout = 5 * time.Second + sizeUpdateFlushInterval = 100 * time.Millisecond +) + +const ( + repoStateMounted = "mounted" + repoStateUnmounted = "unmounted" + repoStateDegraded = "degraded" +) + type Service struct { root string mountRoot string hydrationConcurrency int + prepareTimeout time.Duration logger *slog.Logger registry *registry.Store git *gitstore.Store mu sync.Mutex running map[model.RepoID]*repoRuntime + preparing map[model.RepoID]int64 + prepareSeq int64 mountFailures map[model.RepoID]*mountFailure } @@ -50,9 +65,12 @@ type repoRuntime struct { snapshot *snapshot.Store overlay *overlay.Store hydrator *hydrator.Service + sizes *sizeUpdateBatcher resolver *fusefs.Resolver mfs fusefs.MountedFS + gate *fusefs.ReadyGate state model.RepoRuntimeState + active bool } type aheadBehind struct { @@ -61,18 +79,24 @@ type aheadBehind struct { diverged bool } +type AddRepoOptions struct { + Async bool +} + func New(ctx context.Context, root string, logger *slog.Logger) (*Service, error) { reg, err := registry.New(ctx, filepath.Join(root, "config", "repos.sqlite")) if err != nil { return nil, err } svc := &Service{ - root: root, - logger: logger, - registry: reg, - git: gitstore.New(logger), - running: map[model.RepoID]*repoRuntime{}, - mountFailures: map[model.RepoID]*mountFailure{}, + root: root, + logger: logger, + registry: reg, + git: gitstore.New(logger), + prepareTimeout: defaultPrepareTimeout, + running: map[model.RepoID]*repoRuntime{}, + preparing: map[model.RepoID]int64{}, + mountFailures: map[model.RepoID]*mountFailure{}, } 
svc.git.SetBatchPoolSize(DefaultHydrationConcurrency) return svc, nil @@ -146,9 +170,21 @@ func (s *Service) syncRepos(ctx context.Context) error { continue } s.mu.Lock() - _, running := s.running[repo.ID] + rt, running := s.running[repo.ID] + _, alreadyPreparing := s.preparing[repo.ID] s.mu.Unlock() if running { + s.restartRunningPrepareIfCurrent(ctx, repo, rt, alreadyPreparing) + continue + } + if shouldMountAsync(repo) { + if err := s.mountAsyncRepo(ctx, repo); err != nil { + s.logger.Error("repo async mount failed", "repo", repo.Name, "error", err) + continue + } + if repo.PrepareState == model.PrepareStatePreparing { + s.startPrepareWorker(ctx, repo) + } continue } if mf, ok := s.mountFailures[repo.ID]; ok && time.Since(mf.lastAttempt) < mf.backoff { @@ -196,7 +232,71 @@ func (s *Service) syncRepos(ctx context.Context) error { return nil } +func (s *Service) restartRunningPrepareIfCurrent(ctx context.Context, repo model.RepoConfig, rt *repoRuntime, alreadyPreparing bool) { + if repo.PrepareState != model.PrepareStatePreparing || rt == nil { + return + } + latest, err := s.registry.GetRepo(ctx, repo.Name) + if err != nil { + s.logger.Error("repo prepare state refresh failed", "repo", repo.Name, "error", err) + return + } + if latest.PrepareState != model.PrepareStatePreparing { + return + } + configMatches := samePrepareConfig(rt.cfg, latest) + if alreadyPreparing && configMatches { + return + } + if rt.active || !configMatches { + s.unmount(latest.ID) + if err := s.mountAsyncRepo(ctx, latest); err != nil { + s.logger.Error("repo async remount failed", "repo", latest.Name, "error", err) + return + } + s.supersedePrepare(latest.ID) + s.startPrepareWorker(ctx, latest) + return + } + if alreadyPreparing { + return + } + if s.resetRunningPrepareState(latest) { + s.startPrepareWorker(ctx, latest) + } +} + +func samePrepareConfig(a model.RepoConfig, b model.RepoConfig) bool { + return a.Branch == b.Branch && + a.RemoteURL == b.RemoteURL && + a.PreparedGitDir == 
b.PreparedGitDir && + a.FetchRef == b.FetchRef && + a.GitDir == b.GitDir && + a.MetaDBPath == b.MetaDBPath && + a.OverlayDir == b.OverlayDir && + a.OverlayDBPath == b.OverlayDBPath && + a.BlobCacheDir == b.BlobCacheDir && + a.MountPath == b.MountPath +} + +func (s *Service) prepareConfigStillCurrent(ctx context.Context, cfg model.RepoConfig) bool { + latest, err := s.registry.GetRepo(ctx, cfg.Name) + if err != nil { + s.logger.Error("repo prepare state refresh failed", "repo", cfg.Name, "error", err) + return false + } + s.fillPaths(&latest) + if strings.TrimSpace(latest.FetchRef) == "" { + latest.FetchRef = latest.Branch + } + return samePrepareConfig(cfg, latest) +} + func (s *Service) AddRepo(ctx context.Context, cfg model.RepoConfig) error { + return s.AddRepoWithOptions(ctx, cfg, AddRepoOptions{}) +} + +func (s *Service) AddRepoWithOptions(ctx context.Context, cfg model.RepoConfig, opts AddRepoOptions) error { if err := model.ValidateRepoName(cfg.Name); err != nil { return err } @@ -207,12 +307,39 @@ func (s *Service) AddRepo(ctx context.Context, cfg model.RepoConfig) error { if cfg.RefreshInterval <= 0 { cfg.RefreshInterval = 30 * time.Second } + explicitGitDir := strings.TrimSpace(cfg.GitDir) != "" s.fillPaths(&cfg) + if strings.TrimSpace(cfg.FetchRef) == "" { + cfg.FetchRef = cfg.Branch + } + cloneURL := cfg.RemoteURL + if cfg.PreparedGitDir && !opts.Async { + return fmt.Errorf("--prepared-gitdir requires --async") + } + if opts.Async { + if strings.TrimSpace(cfg.RemoteURL) == "" && !cfg.PreparedGitDir { + return fmt.Errorf("--remote is required unless --prepared-gitdir is set") + } + if cfg.PreparedGitDir && !explicitGitDir { + return fmt.Errorf("--git-dir is required with --prepared-gitdir") + } + if auth.HasInlineCredentials(cfg.RemoteURL) { + return fmt.Errorf("async repositories must use ambient credentials; remove credentials from --remote") + } + cfg.PrepareState = model.PrepareStatePreparing + cfg.PrepareError = "" + } else if 
auth.HasInlineCredentials(cfg.RemoteURL) { + cfg.RemoteURL = "" + } if err := s.registry.AddRepo(ctx, cfg); err != nil { return err } + if opts.Async { + return nil + } // Clone and build snapshot so the repo is ready to mount, but don't start // the FUSE server -- that's the daemon's job. + cfg.RemoteURL = cloneURL return s.prepareRepo(ctx, cfg) } @@ -273,6 +400,12 @@ func (s *Service) FetchNow(ctx context.Context, name string) error { if err != nil { return err } + if cfg.PrepareState == model.PrepareStatePreparing { + return fusefs.ErrRepoNotReady + } + if cfg.PrepareState == model.PrepareStateFailed { + return fmt.Errorf("repo prepare failed: %s", cfg.PrepareError) + } if err := s.git.Fetch(ctx, cfg); err != nil { return err } @@ -288,12 +421,48 @@ func (s *Service) FetchNow(ctx context.Context, name string) error { return nil } +func (s *Service) Prepare(ctx context.Context, name string) error { + cfg, err := s.registry.GetRepo(ctx, name) + if err != nil { + return err + } + if !isAsyncRepo(cfg) { + return s.prepareRepo(ctx, cfg) + } + if cfg.PrepareState == model.PrepareStateReady { + return nil + } + cfg.PrepareState = model.PrepareStatePreparing + cfg.PrepareError = "" + if strings.TrimSpace(cfg.FetchRef) == "" { + cfg.FetchRef = cfg.Branch + } + if err := s.registry.UpdatePrepareStateForConfig(ctx, cfg, model.PrepareStatePreparing, ""); err != nil { + return err + } + cfg.PrepareState = model.PrepareStatePreparing + cfg.PrepareError = "" + if s.resetRunningPrepareState(cfg) { + s.startPrepareWorker(ctx, cfg) + } + return nil +} + func (s *Service) Remount(ctx context.Context, name string) error { cfg, err := s.registry.GetRepo(ctx, name) if err != nil { return err } s.unmount(cfg.ID) + if shouldMountAsync(cfg) { + if err := s.mountAsyncRepo(ctx, cfg); err != nil { + return err + } + if cfg.PrepareState == model.PrepareStatePreparing { + s.startPrepareWorker(ctx, cfg) + } + return nil + } return s.mountRepo(ctx, cfg) } @@ -377,9 +546,12 @@ func (s *Service) 
mountRepo(ctx context.Context, cfg model.RepoConfig) error { resolver := &fusefs.Resolver{Snapshot: snap, Overlay: ov} resolver.SetGeneration(gen) s.refreshCommitTime(ctx, cfg, headOID, resolver, "commit timestamp unavailable, mtime will use generation fallback") + runtimeCtx, cancel := context.WithCancel(ctx) + sizes := newSizeUpdateBatcher(snap, s.logger, cfg.Name) + sizes.Start(runtimeCtx) h.SetOnHydrated(func(_ model.RepoID, objectOID string, size int64) { - snap.UpdateSize(resolver.Generation(), objectOID, size) + sizes.Add(resolver.Generation(), objectOID, size) }) h.Start(s.hydrationWorkers(), cfg) engine := &fusefs.Engine{ @@ -394,8 +566,73 @@ func (s *Service) mountRepo(ctx context.Context, cfg model.RepoConfig) error { s.logger.Error("fuse mount failed, running without FUSE", "repo", cfg.Name, "error", err) mfs = nil } + rt := &repoRuntime{ + cfg: cfg, + ctx: runtimeCtx, + cancel: cancel, + snapshot: snap, + overlay: ov, + hydrator: h, + sizes: sizes, + resolver: resolver, + mfs: mfs, + state: newRuntimeState(cfg.ID, headOID, headRef, gen), + } + s.startRuntime(rt) + s.startRepoBackground(rt) + + return nil +} + +func (s *Service) mountAsyncRepo(ctx context.Context, cfg model.RepoConfig) error { + s.fillPaths(&cfg) + if err := os.MkdirAll(cfg.MountPath, 0o755); err != nil { + return err + } + snap, err := snapshot.New(ctx, cfg.MetaDBPath) + if err != nil { + return err + } + headOID, headRef, gen, _ := snap.ReadState(ctx) + ov, err := overlay.New(ctx, cfg) + if err != nil { + snap.Close() + return err + } + h := hydrator.New(s.git) + resolver := &fusefs.Resolver{Snapshot: snap, Overlay: ov} + resolver.SetGeneration(gen) + if headOID != "" { + s.refreshCommitTime(ctx, cfg, headOID, resolver, "commit timestamp unavailable, mtime will use generation fallback") + } runtimeCtx, cancel := context.WithCancel(ctx) + sizes := newSizeUpdateBatcher(snap, s.logger, cfg.Name) + sizes.Start(runtimeCtx) + h.SetOnHydrated(func(_ model.RepoID, objectOID string, size 
int64) { + sizes.Add(resolver.Generation(), objectOID, size) + }) + h.Start(s.hydrationWorkers(), cfg) + gate := fusefs.NewReadyGate(false) + if cfg.PrepareState == model.PrepareStateFailed { + gate.MarkFailed(prepareGateError(cfg.PrepareError)) + } + engine := &fusefs.Engine{ + Resolver: resolver, + Repo: cfg, + Overlay: ov, + Hydrator: h, + } + + mfs, err := fusefs.MountRepoWithGate(cfg, resolver, engine, gate) + if err != nil { + s.logger.Error("fuse mount failed, running without FUSE", "repo", cfg.Name, "error", err) + mfs = nil + } + state := cfg.PrepareState + if strings.TrimSpace(state) == "" { + state = model.PrepareStatePreparing + } rt := &repoRuntime{ cfg: cfg, ctx: runtimeCtx, @@ -403,12 +640,291 @@ func (s *Service) mountRepo(ctx context.Context, cfg model.RepoConfig) error { snapshot: snap, overlay: ov, hydrator: h, + sizes: sizes, resolver: resolver, mfs: mfs, - state: newRuntimeState(cfg.ID, headOID, headRef, gen), + gate: gate, + state: model.RepoRuntimeState{ + RepoID: cfg.ID, + CurrentHEADOID: headOID, + CurrentHEADRef: headRef, + SnapshotGeneration: gen, + LastFetchResult: "never", + State: state, + PrepareError: cfg.PrepareError, + }, } s.startRuntime(rt) + return nil +} + +func (s *Service) startPrepareWorker(ctx context.Context, cfg model.RepoConfig) { + s.mu.Lock() + if _, ok := s.preparing[cfg.ID]; ok { + s.mu.Unlock() + return + } + s.prepareSeq++ + token := s.prepareSeq + workerCtx := ctx + if rt := s.running[cfg.ID]; rt != nil && rt.ctx != nil { + workerCtx = rt.ctx + } + s.preparing[cfg.ID] = token + s.mu.Unlock() + + go func() { + defer func() { + s.mu.Lock() + if s.preparing[cfg.ID] == token { + delete(s.preparing, cfg.ID) + } + s.mu.Unlock() + }() + prepareCtx, cancel := context.WithTimeout(workerCtx, s.prepareTimeoutDuration()) + defer cancel() + if err := s.runPrepare(prepareCtx, cfg); err != nil { + s.logger.Error("repo prepare failed", "repo", cfg.Name, "error", err) + } + }() +} + +func (s *Service) supersedePrepare(id 
model.RepoID) { + s.mu.Lock() + delete(s.preparing, id) + s.mu.Unlock() +} + +func (s *Service) prepareTimeoutDuration() time.Duration { + if s.prepareTimeout > 0 { + return s.prepareTimeout + } + return defaultPrepareTimeout +} + +func (s *Service) resetRunningPrepareState(cfg model.RepoConfig) bool { + s.mu.Lock() + defer s.mu.Unlock() + rt, ok := s.running[cfg.ID] + if !ok || rt.gate == nil { + return false + } + rt.gate.Reset() + rt.cfg = cfg + rt.state.State = model.PrepareStatePreparing + rt.state.PrepareError = "" + return true +} + +func (s *Service) runPrepare(ctx context.Context, cfg model.RepoConfig) error { + s.fillPaths(&cfg) + if strings.TrimSpace(cfg.FetchRef) == "" { + cfg.FetchRef = cfg.Branch + } + + fail := func(err error) error { + if errors.Is(ctx.Err(), context.Canceled) || errors.Is(err, context.Canceled) { + return err + } + stateErr := err + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + stateErr = errors.New("prepare timed out") + } + stateCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), prepareStateWriteTimeout) + defer cancel() + if !s.prepareConfigStillCurrent(stateCtx, cfg) { + return err + } + _ = s.setPrepareState(stateCtx, cfg, model.PrepareStateFailed, stateErr) + return err + } + + if cfg.PreparedGitDir { + if err := s.git.ValidatePreparedGitDir(ctx, cfg); err != nil { + return fail(err) + } + if err := s.git.FetchRefNonInteractive(ctx, cfg, cfg.FetchRef); err != nil { + return fail(err) + } + if err := s.git.PrepareFetchedBranch(ctx, cfg, cfg.FetchRef); err != nil { + return fail(err) + } + } else { + if strings.TrimSpace(cfg.RemoteURL) == "" { + return fail(errors.New("remote URL is required for async clone")) + } + if _, err := os.Stat(cfg.GitDir); err == nil { + if err := s.git.PrepareExistingCloneNonInteractive(ctx, cfg); err != nil { + return fail(err) + } + } else { + if err := s.git.ValidateAmbientRemote(cfg); err != nil { + return fail(err) + } + if err := s.git.CloneBloblessNonInteractive(ctx, cfg); 
err != nil { + return fail(err) + } + if !sameBranchRef(cfg.FetchRef, cfg.Branch) { + if err := s.git.PrepareExistingCloneNonInteractive(ctx, cfg); err != nil { + return fail(err) + } + } + } + } + + headOID, headRef, err := s.git.ResolveHEAD(ctx, cfg) + if err != nil { + return fail(err) + } + snap, closeSnap, err := s.snapshotForPrepare(ctx, cfg) + if err != nil { + return fail(err) + } + if closeSnap { + defer snap.Close() + } + gen, _, err := s.publishSnapshot(ctx, cfg, snap, headOID, headRef) + if err != nil { + return fail(err) + } + latest, err := s.registry.GetRepo(ctx, cfg.Name) + if err != nil { + return fail(err) + } + s.fillPaths(&latest) + if strings.TrimSpace(latest.FetchRef) == "" { + latest.FetchRef = latest.Branch + } + if !samePrepareConfig(cfg, latest) { + return errors.New("prepare superseded by newer repo config") + } + if err := s.setPrepareStateBeforeReadyGate(ctx, cfg); err != nil { + return fail(err) + } + if err := s.completePreparedRuntime(ctx, cfg, headOID, headRef, gen); err != nil { + return fail(err) + } + return nil +} + +func sameBranchRef(fetchRef string, branch string) bool { + return branchRefName(fetchRef) == branchRefName(branch) +} + +func branchRefName(ref string) string { + ref = strings.TrimSpace(ref) + for _, prefix := range []string{"refs/heads/", "refs/remotes/origin/", "origin/"} { + if strings.HasPrefix(ref, prefix) { + return strings.TrimPrefix(ref, prefix) + } + } + if strings.HasPrefix(ref, "refs/") { + return "" + } + return ref +} + +func (s *Service) snapshotForPrepare(ctx context.Context, cfg model.RepoConfig) (*snapshot.Store, bool, error) { + s.mu.Lock() + rt := s.running[cfg.ID] + s.mu.Unlock() + if rt != nil && rt.snapshot != nil { + return rt.snapshot, false, nil + } + snap, err := snapshot.New(ctx, cfg.MetaDBPath) + if err != nil { + return nil, false, err + } + return snap, true, nil +} +func (s *Service) completePreparedRuntime(ctx context.Context, cfg model.RepoConfig, headOID string, headRef string, 
gen int64) error { + s.mu.Lock() + rt := s.running[cfg.ID] + if rt != nil && !samePrepareConfig(rt.cfg, cfg) { + s.mu.Unlock() + return registry.ErrRepoChanged + } + s.mu.Unlock() + if rt == nil { + return nil + } + if !s.prepareConfigStillCurrent(ctx, cfg) { + return registry.ErrRepoChanged + } + if !s.prepareConfigStillCurrent(ctx, cfg) { + return registry.ErrRepoChanged + } + baseLookup := func(path string) (model.BaseNode, bool) { + return rt.snapshot.GetNode(gen, path) + } + if err := rt.overlay.Reconcile(ctx, baseLookup); err != nil { + return err + } + if !s.prepareConfigStillCurrent(ctx, cfg) { + return registry.ErrRepoChanged + } + s.refreshCommitTime(ctx, cfg, headOID, rt.resolver, "commit timestamp unavailable") + rt.resolver.SetGeneration(gen) + s.mu.Lock() + if s.running[cfg.ID] != rt || !samePrepareConfig(rt.cfg, cfg) { + s.mu.Unlock() + return registry.ErrRepoChanged + } + rt.cfg = cfg + rt.cfg.PrepareState = model.PrepareStateReady + rt.cfg.PrepareError = "" + setHeadState(&rt.state, headOID, headRef, gen) + rt.state.State = repoStateMounted + rt.state.PrepareError = "" + s.mu.Unlock() + rt.gate.MarkReady() + s.startRepoBackground(rt) + return nil +} + +func (s *Service) setPrepareState(ctx context.Context, cfg model.RepoConfig, state string, stateErr error) error { + return s.applyPrepareState(ctx, cfg, state, stateErr, true) +} + +func (s *Service) setPrepareStateBeforeReadyGate(ctx context.Context, cfg model.RepoConfig) error { + return s.applyPrepareState(ctx, cfg, model.PrepareStateReady, nil, false) +} + +func (s *Service) applyPrepareState(ctx context.Context, cfg model.RepoConfig, state string, stateErr error, applyReadyRuntime bool) error { + msg := "" + if stateErr != nil { + msg = auth.RedactString(stateErr.Error()) + } + if err := s.registry.UpdatePrepareStateForConfig(ctx, cfg, state, msg); err != nil { + return err + } + s.mu.Lock() + if rt, ok := s.running[cfg.ID]; ok { + if !samePrepareConfig(rt.cfg, cfg) { + s.mu.Unlock() + return 
nil + } + rt.cfg.PrepareState = state + rt.cfg.PrepareError = msg + rt.state.PrepareError = msg + if state != model.PrepareStateReady || applyReadyRuntime { + rt.state.State = runtimeStateForPrepareState(state) + } + if rt.gate != nil { + switch state { + case model.PrepareStateFailed: + rt.gate.MarkFailed(prepareGateError(msg)) + case model.PrepareStateReady: + if applyReadyRuntime { + rt.gate.MarkReady() + } + default: + rt.gate.Reset() + } + } + } + s.mu.Unlock() return nil } @@ -505,9 +1021,11 @@ func (s *Service) refreshLoop(rt *repoRuntime) { func (s *Service) readPersistedStatus(ctx context.Context, cfg model.RepoConfig) model.RepoRuntimeState { // One-shot CLI process: reconstruct state from persisted stores and // OS-level mount check since we don't share memory with the daemon. - st := model.RepoRuntimeState{RepoID: cfg.ID, State: "unmounted", LastFetchResult: "never"} - if isMounted(cfg.MountPath) { - st.State = "mounted" + st := model.RepoRuntimeState{RepoID: cfg.ID, State: repoStateUnmounted, LastFetchResult: "never", PrepareError: cfg.PrepareError} + if isPendingOrFailedPrepareState(cfg.PrepareState) { + st.State = cfg.PrepareState + } else if isMounted(cfg.MountPath) { + st.State = repoStateMounted } if cfg.MetaDBPath != "" { if snap, err := snapshot.New(ctx, cfg.MetaDBPath); err == nil { @@ -547,6 +1065,90 @@ func (s *Service) publishSnapshot(ctx context.Context, cfg model.RepoConfig, sna return gen, "", nil } +type sizeUpdateBatcher struct { + snapshot *snapshot.Store + logger *slog.Logger + repoName string + interval time.Duration + stopOnce sync.Once + stopCh chan struct{} + done chan struct{} + mu sync.Mutex + pending map[int64]map[string]int64 + stopped bool +} + +func newSizeUpdateBatcher(snap *snapshot.Store, logger *slog.Logger, repoName string) *sizeUpdateBatcher { + return &sizeUpdateBatcher{ + snapshot: snap, + logger: logger, + repoName: repoName, + interval: sizeUpdateFlushInterval, + stopCh: make(chan struct{}), + done: make(chan 
struct{}), + pending: map[int64]map[string]int64{}, + } +} + +func (b *sizeUpdateBatcher) Start(ctx context.Context) { + go b.run(ctx) +} + +func (b *sizeUpdateBatcher) Add(generation int64, objectOID string, size int64) { + if generation <= 0 || strings.TrimSpace(objectOID) == "" { + return + } + b.mu.Lock() + defer b.mu.Unlock() + if b.stopped { + return + } + if b.pending[generation] == nil { + b.pending[generation] = map[string]int64{} + } + b.pending[generation][objectOID] = size +} + +func (b *sizeUpdateBatcher) Stop() { + b.stopOnce.Do(func() { + b.mu.Lock() + b.stopped = true + b.mu.Unlock() + close(b.stopCh) + <-b.done + b.Flush() + }) +} + +func (b *sizeUpdateBatcher) run(ctx context.Context) { + ticker := time.NewTicker(b.interval) + defer ticker.Stop() + defer close(b.done) + defer b.Flush() + for { + select { + case <-ctx.Done(): + return + case <-b.stopCh: + return + case <-ticker.C: + b.Flush() + } + } +} + +func (b *sizeUpdateBatcher) Flush() { + b.mu.Lock() + pending := b.pending + b.pending = map[int64]map[string]int64{} + b.mu.Unlock() + for gen, sizes := range pending { + if err := b.snapshot.UpdateSizes(context.Background(), gen, sizes); err != nil && b.logger != nil { + b.logger.Warn("snapshot size backfill failed", "repo", b.repoName, "generation", gen, "error", err) + } + } +} + func (s *Service) refreshCommitTime(ctx context.Context, cfg model.RepoConfig, oid string, resolver *fusefs.Resolver, warnMsg string) { if ts, err := s.git.CommitTimestamp(ctx, cfg, oid); err == nil { resolver.SetCommitTime(ts) @@ -568,18 +1170,28 @@ func (s *Service) startRuntime(rt *repoRuntime) { s.running[rt.cfg.ID] = rt s.mu.Unlock() + if rt.mfs != nil { + go func() { + _ = rt.mfs.Join(rt.ctx) + }() + } +} + +func (s *Service) startRepoBackground(rt *repoRuntime) { + s.mu.Lock() + if rt.active { + s.mu.Unlock() + return + } + rt.active = true + s.mu.Unlock() + go s.refreshLoop(rt) w := watcher.New(500 * time.Millisecond) go w.Watch(rt.ctx, rt.cfg.GitDir, func() 
{ s.onHEADChanged(rt.ctx, rt) }) - - if rt.mfs != nil { - go func() { - _ = rt.mfs.Join(rt.ctx) - }() - } } func newRuntimeState(repoID model.RepoID, headOID string, headRef string, gen int64) model.RepoRuntimeState { @@ -589,7 +1201,7 @@ func newRuntimeState(repoID model.RepoID, headOID string, headRef string, gen in CurrentHEADRef: headRef, SnapshotGeneration: gen, LastFetchResult: "never", - State: "ready", + State: repoStateMounted, } } @@ -613,13 +1225,13 @@ func markFetchSuccess(st *model.RepoRuntimeState, at time.Time, state aheadBehin func markFetchResult(st *model.RepoRuntimeState, at time.Time, result string) { st.LastFetchResult = result st.LastFetchAt = at - if st.State == "degraded" && result == "ok" { - st.State = "ready" + if st.State == repoStateDegraded && result == "ok" { + st.State = repoStateMounted } } func markFetchFailure(st *model.RepoRuntimeState, result string) { - st.State = "degraded" + st.State = repoStateDegraded st.LastFetchResult = result } @@ -671,14 +1283,24 @@ func (s *Service) stopRuntime(rt *repoRuntime) { if rt.cancel != nil { rt.cancel() } + if rt.gate != nil { + rt.gate.MarkFailed(context.Canceled) + } if rt.mfs != nil { _ = rt.mfs.Unmount() } if rt.hydrator != nil { rt.hydrator.Stop() } - _ = rt.snapshot.Close() - _ = rt.overlay.Close() + if rt.sizes != nil { + rt.sizes.Stop() + } + if rt.snapshot != nil { + _ = rt.snapshot.Close() + } + if rt.overlay != nil { + _ = rt.overlay.Close() + } } func (s *Service) fillPaths(cfg *model.RepoConfig) { @@ -719,3 +1341,34 @@ func ParseRefresh(v string) (time.Duration, error) { } return d, nil } + +func isAsyncRepo(cfg model.RepoConfig) bool { + return cfg.PreparedGitDir || strings.TrimSpace(cfg.PrepareState) != "" +} + +func shouldMountAsync(cfg model.RepoConfig) bool { + return isPendingOrFailedPrepareState(cfg.PrepareState) +} + +func isPendingOrFailedPrepareState(state string) bool { + switch strings.TrimSpace(state) { + case model.PrepareStatePreparing, model.PrepareStateFailed: + 
return true + default: + return false + } +} + +func runtimeStateForPrepareState(state string) string { + if state == model.PrepareStateReady { + return repoStateMounted + } + return state +} + +func prepareGateError(msg string) error { + if strings.TrimSpace(msg) == "" { + return fusefs.ErrRepoNotReady + } + return errors.New(msg) +} diff --git a/internal/fusefs/fuse_unix.go b/internal/fusefs/fuse_unix.go index ef1b633..e5c4812 100644 --- a/internal/fusefs/fuse_unix.go +++ b/internal/fusefs/fuse_unix.go @@ -59,14 +59,21 @@ type DirHandle struct { } type FileHandle struct { - inode *InodeRef - path string + mu sync.Mutex + inode *InodeRef + path string + cacheFile *os.File + cacheGeneration int64 + invalidateSeq uint64 } -// ReaddirEntry holds a child name and type, avoiding per-child Getattr calls. +// ReaddirEntry holds child metadata, avoiding per-child Getattr or snapshot lookups. type ReaddirEntry struct { - Name string - Type string // file, dir, symlink + Name string + Type string // file, dir, symlink + ObjectOID string + SizeState string + SizeBytes int64 } func NewArtifactFuse(repo model.RepoConfig, resolver *Resolver, engine *Engine) *ArtifactFuse { @@ -147,6 +154,83 @@ func (fs *ArtifactFuse) fileHandle(handleID fuseops.HandleID) (*FileHandle, erro return fh, nil } +func (fs *ArtifactFuse) closeCachedFilesForPath(path string) { + fs.mu.RLock() + var handles []*FileHandle + for _, fh := range fs.fileHandles { + if fh.path == path { + handles = append(handles, fh) + } + } + fs.mu.RUnlock() + for _, fh := range handles { + fh.closeCachedFile() + } +} + +func (fh *FileHandle) read(ctx context.Context, engine *Engine, off int64, size int) ([]byte, error) { + currentGen := engine.Resolver.Generation() + fh.mu.Lock() + if fh.cacheFile != nil && fh.cacheGeneration == currentGen { + defer fh.mu.Unlock() + return readFileChunkFrom(fh.cacheFile, off, size) + } + if fh.cacheFile != nil { + f := fh.cacheFile + fh.cacheFile = nil + fh.cacheGeneration = 0 + 
fh.invalidateSeq++ + fh.mu.Unlock() + _ = f.Close() + fh.mu.Lock() + } + seq := fh.invalidateSeq + fh.mu.Unlock() + + cachePath, gen, ok, err := engine.BaseCachePath(ctx, fh.path) + if err != nil { + return nil, err + } + if !ok { + return engine.Read(ctx, fh.path, off, size) + } + f, err := os.Open(cachePath) + if err != nil { + return nil, err + } + + fh.mu.Lock() + if fh.invalidateSeq != seq || gen != engine.Resolver.Generation() { + fh.mu.Unlock() + _ = f.Close() + return engine.Read(ctx, fh.path, off, size) + } + if fh.cacheFile != nil && fh.cacheGeneration == gen { + _ = f.Close() + f = fh.cacheFile + } else { + if fh.cacheFile != nil { + _ = fh.cacheFile.Close() + } + fh.cacheFile = f + fh.cacheGeneration = gen + } + defer fh.mu.Unlock() + return readFileChunkFrom(f, off, size) +} + +func (fh *FileHandle) closeCachedFile() { + fh.mu.Lock() + f := fh.cacheFile + fh.cacheFile = nil + fh.cacheGeneration = 0 + fh.invalidateSeq++ + fh.mu.Unlock() + if f != nil { + _ = f.Close() + } +} + // --- FUSE operations --- func (fs *ArtifactFuse) StatFS(_ context.Context, op *fuseops.StatFSOp) error { @@ -206,6 +290,20 @@ func (fs *ArtifactFuse) GetInodeAttributes(_ context.Context, op *fuseops.GetIno return err } + if ref.IsRoot { + if fs.resolver != nil { + if mode, size, typ, mtime, ctime, err := fs.resolver.Getattr(ref.Path); err == nil { + op.Attributes = inodeAttrs(mode, uint64(size), typ, mtime, ctime) + op.AttributesExpiration = attrExpiry(time.Second) + return nil + } + } + now := time.Now() + op.Attributes = inodeAttrs(ref.Mode, 4096, "dir", now, now) + op.AttributesExpiration = attrExpiry(time.Second) + return nil + } + if ref.Path == ".git" { op.Attributes = fs.gitFileAttrs() op.AttributesExpiration = attrExpiry(time.Minute) @@ -227,18 +325,22 @@ func (fs *ArtifactFuse) SetInodeAttributes(ctx context.Context, op *fuseops.SetI return err } if op.Size != nil { + fs.closeCachedFilesForPath(ref.Path) if err := fs.engine.Truncate(ctx, ref.Path, int64(*op.Size)); err 
!= nil { return syscall.EIO } + fs.closeCachedFilesForPath(ref.Path) } // Handle mtime updates (e.g., from touch) if op.Mtime != nil { + fs.closeCachedFilesForPath(ref.Path) if err := fs.engine.SetMtime(ctx, ref.Path, *op.Mtime); err != nil { if errors.Is(err, iofs.ErrInvalid) { return syscall.ENOTSUP } return syscall.EIO } + fs.closeCachedFilesForPath(ref.Path) } mode, size, typ, mtime, ctime, err := fs.resolver.Getattr(ref.Path) if err != nil { @@ -363,7 +465,7 @@ func (fs *ArtifactFuse) ReadFile(ctx context.Context, op *fuseops.ReadFileOp) er return nil } - data, err := fs.engine.Read(ctx, fh.path, op.Offset, int(op.Size)) + data, err := fh.read(ctx, fs.engine, op.Offset, int(op.Size)) if err != nil { if os.IsNotExist(err) { return syscall.ENOENT @@ -380,10 +482,12 @@ func (fs *ArtifactFuse) WriteFile(ctx context.Context, op *fuseops.WriteFileOp) if err != nil { return err } + fs.closeCachedFilesForPath(fh.path) _, err = fs.engine.Write(ctx, fh.path, op.Offset, op.Data) if err != nil { return syscall.EIO } + fs.closeCachedFilesForPath(fh.path) return nil } @@ -449,9 +553,11 @@ func (fs *ArtifactFuse) Unlink(ctx context.Context, op *fuseops.UnlinkOp) error if err != nil { return err } + fs.closeCachedFilesForPath(childPath) if err := fs.engine.Unlink(ctx, childPath); err != nil { return syscall.EIO } + fs.closeCachedFilesForPath(childPath) return nil } @@ -466,12 +572,16 @@ func (fs *ArtifactFuse) Rename(ctx context.Context, op *fuseops.RenameOp) error } oldPath := cleanChildPath(oldParent.Path, op.OldName) newPath := cleanChildPath(newParent.Path, op.NewName) + fs.closeCachedFilesForPath(oldPath) + fs.closeCachedFilesForPath(newPath) if err := fs.engine.Rename(ctx, oldPath, newPath); err != nil { if errors.Is(err, iofs.ErrInvalid) { return syscall.ENOTSUP } return syscall.EIO } + fs.closeCachedFilesForPath(oldPath) + fs.closeCachedFilesForPath(newPath) return nil } @@ -529,8 +639,12 @@ func (fs *ArtifactFuse) SyncFile(_ context.Context, _ *fuseops.SyncFileOp) 
error func (fs *ArtifactFuse) ReleaseFileHandle(_ context.Context, op *fuseops.ReleaseFileHandleOp) error { fs.mu.Lock() + fh := fs.fileHandles[op.Handle] delete(fs.fileHandles, op.Handle) fs.mu.Unlock() + if fh != nil { + fh.closeCachedFile() + } return nil } @@ -559,8 +673,12 @@ func (m *mountedFSWrapper) Unmount() error { } func MountRepo(repo model.RepoConfig, resolver *Resolver, engine *Engine) (MountedFS, error) { + return MountRepoWithGate(repo, resolver, engine, nil) +} + +func MountRepoWithGate(repo model.RepoConfig, resolver *Resolver, engine *Engine, gate *ReadyGate) (MountedFS, error) { fsint := NewArtifactFuse(repo, resolver, engine) - server := fuseutil.NewFileSystemServer(fsint) + server := fuseutil.NewFileSystemServer(NewGatedFileSystem(fsint, gate)) mountCfg := &fuse.MountConfig{ FSName: "artifact-fs:" + repo.Name, diff --git a/internal/fusefs/fuse_unix_test.go b/internal/fusefs/fuse_unix_test.go index 4df7438..a03d9eb 100644 --- a/internal/fusefs/fuse_unix_test.go +++ b/internal/fusefs/fuse_unix_test.go @@ -3,8 +3,12 @@ package fusefs import ( + "context" "testing" "time" + + "github.com/cloudflare/artifact-fs/internal/model" + "github.com/jacobsa/fuse/fuseops" ) func TestInodeAttrsPreservesSeparateTimes(t *testing.T) { @@ -46,3 +50,48 @@ func TestGitFileAttrsUsesOneTimestamp(t *testing.T) { t.Fatalf("expected .git attrs to use one timestamp: atime=%v mtime=%v ctime=%v", attr.Atime, attr.Mtime, attr.Ctime) } } + +func TestRootInodeAttributesDoNotRequireResolver(t *testing.T) { + fs := NewArtifactFuse(model.RepoConfig{Name: "repo", GitDir: "/tmp/repo.git"}, nil, nil) + op := &fuseops.GetInodeAttributesOp{Inode: fuseops.RootInodeID} + + if err := fs.GetInodeAttributes(context.Background(), op); err != nil { + t.Fatalf("GetInodeAttributes(root): %v", err) + } + if !op.Attributes.Mode.IsDir() { + t.Fatalf("root mode = %#o, want directory", op.Attributes.Mode) + } + if op.Attributes.Size == 0 { + t.Fatal("root size = 0, want non-zero placeholder size") 
+ } +} + +func TestRootInodeAttributesUseStableResolverAttrsWhenReady(t *testing.T) { + resolver := &Resolver{ + Snapshot: &fakeSnapshot{nodes: map[string]model.BaseNode{ + ".": {Path: ".", Type: "dir", Mode: 0o755, SizeBytes: 4096}, + }}, + Overlay: &fakeOverlay{entries: map[string]model.OverlayEntry{}}, + } + resolver.SetGeneration(7) + resolver.SetCommitTime(1_700_000_000) + fs := NewArtifactFuse(model.RepoConfig{Name: "repo", GitDir: "/tmp/repo.git"}, resolver, nil) + + first := &fuseops.GetInodeAttributesOp{Inode: fuseops.RootInodeID} + if err := fs.GetInodeAttributes(context.Background(), first); err != nil { + t.Fatalf("first GetInodeAttributes(root): %v", err) + } + time.Sleep(2 * time.Millisecond) + second := &fuseops.GetInodeAttributesOp{Inode: fuseops.RootInodeID} + if err := fs.GetInodeAttributes(context.Background(), second); err != nil { + t.Fatalf("second GetInodeAttributes(root): %v", err) + } + + want := time.Unix(1_700_000_000, 0) + if !first.Attributes.Mtime.Equal(want) || !second.Attributes.Mtime.Equal(want) { + t.Fatalf("root mtime = %v then %v, want stable %v", first.Attributes.Mtime, second.Attributes.Mtime, want) + } + if !first.Attributes.Ctime.Equal(second.Attributes.Ctime) { + t.Fatalf("root ctime changed: %v then %v", first.Attributes.Ctime, second.Attributes.Ctime) + } +} diff --git a/internal/fusefs/gated_fs.go b/internal/fusefs/gated_fs.go new file mode 100644 index 0000000..93b0783 --- /dev/null +++ b/internal/fusefs/gated_fs.go @@ -0,0 +1,232 @@ +//go:build !windows + +package fusefs + +import ( + "context" + "syscall" + + "github.com/jacobsa/fuse/fuseops" + "github.com/jacobsa/fuse/fuseutil" +) + +type gatedFileSystem struct { + next fuseutil.FileSystem + gate *ReadyGate +} + +func NewGatedFileSystem(next fuseutil.FileSystem, gate *ReadyGate) fuseutil.FileSystem { + if gate == nil { + return next + } + return &gatedFileSystem{next: next, gate: gate} +} + +func (fs *gatedFileSystem) wait(ctx context.Context) error { + if err := 
fs.gate.Wait(ctx); err != nil { + return syscall.EIO + } + return nil +} + +func (fs *gatedFileSystem) StatFS(ctx context.Context, op *fuseops.StatFSOp) error { + return fs.next.StatFS(ctx, op) +} + +func (fs *gatedFileSystem) LookUpInode(ctx context.Context, op *fuseops.LookUpInodeOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.LookUpInode(ctx, op) +} + +func (fs *gatedFileSystem) GetInodeAttributes(ctx context.Context, op *fuseops.GetInodeAttributesOp) error { + if op.Inode != fuseops.RootInodeID { + if err := fs.wait(ctx); err != nil { + return err + } + } + return fs.next.GetInodeAttributes(ctx, op) +} + +func (fs *gatedFileSystem) SetInodeAttributes(ctx context.Context, op *fuseops.SetInodeAttributesOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.SetInodeAttributes(ctx, op) +} + +func (fs *gatedFileSystem) ForgetInode(ctx context.Context, op *fuseops.ForgetInodeOp) error { + return fs.next.ForgetInode(ctx, op) +} + +func (fs *gatedFileSystem) BatchForget(ctx context.Context, op *fuseops.BatchForgetOp) error { + return fs.next.BatchForget(ctx, op) +} + +func (fs *gatedFileSystem) MkDir(ctx context.Context, op *fuseops.MkDirOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.MkDir(ctx, op) +} + +func (fs *gatedFileSystem) MkNode(ctx context.Context, op *fuseops.MkNodeOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.MkNode(ctx, op) +} + +func (fs *gatedFileSystem) CreateFile(ctx context.Context, op *fuseops.CreateFileOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.CreateFile(ctx, op) +} + +func (fs *gatedFileSystem) CreateLink(ctx context.Context, op *fuseops.CreateLinkOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.CreateLink(ctx, op) +} + +func (fs *gatedFileSystem) CreateSymlink(ctx context.Context, op *fuseops.CreateSymlinkOp) error { + if err := 
fs.wait(ctx); err != nil { + return err + } + return fs.next.CreateSymlink(ctx, op) +} + +func (fs *gatedFileSystem) Rename(ctx context.Context, op *fuseops.RenameOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.Rename(ctx, op) +} + +func (fs *gatedFileSystem) RmDir(ctx context.Context, op *fuseops.RmDirOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.RmDir(ctx, op) +} + +func (fs *gatedFileSystem) Unlink(ctx context.Context, op *fuseops.UnlinkOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.Unlink(ctx, op) +} + +func (fs *gatedFileSystem) OpenDir(ctx context.Context, op *fuseops.OpenDirOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.OpenDir(ctx, op) +} + +func (fs *gatedFileSystem) ReadDir(ctx context.Context, op *fuseops.ReadDirOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.ReadDir(ctx, op) +} + +func (fs *gatedFileSystem) ReadDirPlus(ctx context.Context, op *fuseops.ReadDirPlusOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.ReadDirPlus(ctx, op) +} + +func (fs *gatedFileSystem) ReleaseDirHandle(ctx context.Context, op *fuseops.ReleaseDirHandleOp) error { + return fs.next.ReleaseDirHandle(ctx, op) +} + +func (fs *gatedFileSystem) OpenFile(ctx context.Context, op *fuseops.OpenFileOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.OpenFile(ctx, op) +} + +func (fs *gatedFileSystem) ReadFile(ctx context.Context, op *fuseops.ReadFileOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.ReadFile(ctx, op) +} + +func (fs *gatedFileSystem) WriteFile(ctx context.Context, op *fuseops.WriteFileOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.WriteFile(ctx, op) +} + +func (fs *gatedFileSystem) SyncFile(ctx context.Context, op *fuseops.SyncFileOp) error { + 
return fs.next.SyncFile(ctx, op) +} + +func (fs *gatedFileSystem) FlushFile(ctx context.Context, op *fuseops.FlushFileOp) error { + return fs.next.FlushFile(ctx, op) +} + +func (fs *gatedFileSystem) ReleaseFileHandle(ctx context.Context, op *fuseops.ReleaseFileHandleOp) error { + return fs.next.ReleaseFileHandle(ctx, op) +} + +func (fs *gatedFileSystem) ReadSymlink(ctx context.Context, op *fuseops.ReadSymlinkOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.ReadSymlink(ctx, op) +} + +func (fs *gatedFileSystem) RemoveXattr(ctx context.Context, op *fuseops.RemoveXattrOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.RemoveXattr(ctx, op) +} + +func (fs *gatedFileSystem) GetXattr(ctx context.Context, op *fuseops.GetXattrOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.GetXattr(ctx, op) +} + +func (fs *gatedFileSystem) ListXattr(ctx context.Context, op *fuseops.ListXattrOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.ListXattr(ctx, op) +} + +func (fs *gatedFileSystem) SetXattr(ctx context.Context, op *fuseops.SetXattrOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.SetXattr(ctx, op) +} + +func (fs *gatedFileSystem) Fallocate(ctx context.Context, op *fuseops.FallocateOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.Fallocate(ctx, op) +} + +func (fs *gatedFileSystem) SyncFS(ctx context.Context, op *fuseops.SyncFSOp) error { + if err := fs.wait(ctx); err != nil { + return err + } + return fs.next.SyncFS(ctx, op) +} + +func (fs *gatedFileSystem) Destroy() { + fs.next.Destroy() +} diff --git a/internal/fusefs/gated_fs_test.go b/internal/fusefs/gated_fs_test.go new file mode 100644 index 0000000..123d667 --- /dev/null +++ b/internal/fusefs/gated_fs_test.go @@ -0,0 +1,68 @@ +//go:build !windows + +package fusefs + +import ( + "context" + "errors" + "sync/atomic" + 
"syscall" + "testing" + "time" + + "github.com/jacobsa/fuse/fuseops" + "github.com/jacobsa/fuse/fuseutil" +) + +type recordingFS struct { + fuseutil.NotImplementedFileSystem + lookups atomic.Int32 +} + +func (fs *recordingFS) LookUpInode(context.Context, *fuseops.LookUpInodeOp) error { + fs.lookups.Add(1) + return nil +} + +func TestGatedFileSystemBlocksUntilReady(t *testing.T) { + next := &recordingFS{} + gate := NewReadyGate(false) + fs := NewGatedFileSystem(next, gate) + + done := make(chan error, 1) + go func() { + done <- fs.LookUpInode(context.Background(), &fuseops.LookUpInodeOp{}) + }() + + select { + case err := <-done: + t.Fatalf("LookUpInode returned before ready: %v", err) + case <-time.After(20 * time.Millisecond): + } + if got := next.lookups.Load(); got != 0 { + t.Fatalf("lookups before ready = %d, want 0", got) + } + + gate.MarkReady() + if err := <-done; err != nil { + t.Fatalf("LookUpInode after ready: %v", err) + } + if got := next.lookups.Load(); got != 1 { + t.Fatalf("lookups after ready = %d, want 1", got) + } +} + +func TestGatedFileSystemFailedGateReturnsEIO(t *testing.T) { + next := &recordingFS{} + gate := NewReadyGate(false) + gate.MarkFailed(errors.New("clone failed")) + fs := NewGatedFileSystem(next, gate) + + err := fs.LookUpInode(context.Background(), &fuseops.LookUpInodeOp{}) + if !errors.Is(err, syscall.EIO) { + t.Fatalf("LookUpInode error = %v, want EIO", err) + } + if got := next.lookups.Load(); got != 0 { + t.Fatalf("lookups after failed gate = %d, want 0", got) + } +} diff --git a/internal/fusefs/merged.go b/internal/fusefs/merged.go index eb09247..3fa8869 100644 --- a/internal/fusefs/merged.go +++ b/internal/fusefs/merged.go @@ -131,12 +131,13 @@ func (r *Resolver) ReaddirTyped(ctx context.Context, path string) ([]ReaddirEntr type entry struct { name string typ string + base model.BaseNode } set := map[string]entry{} for _, c := range children { name := filepath.Base(c.Path) if name != "." 
{ - set[name] = entry{name: name, typ: c.Type} + set[name] = entry{name: name, typ: c.Type, base: c} } } ovEntries, err := r.Overlay.ListByPrefix(ctx, path) @@ -163,7 +164,7 @@ func (r *Resolver) ReaddirTyped(ctx context.Context, path string) ([]ReaddirEntr } out := make([]ReaddirEntry, 0, len(set)) for _, e := range set { - out = append(out, ReaddirEntry{Name: e.name, Type: e.typ}) + out = append(out, ReaddirEntry{Name: e.name, Type: e.typ, ObjectOID: e.base.ObjectOID, SizeState: e.base.SizeState, SizeBytes: e.base.SizeBytes}) } sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name }) return out, nil diff --git a/internal/fusefs/merged_test.go b/internal/fusefs/merged_test.go index 490b60e..f1d6ea7 100644 --- a/internal/fusefs/merged_test.go +++ b/internal/fusefs/merged_test.go @@ -483,7 +483,7 @@ func TestReaddirTypedReturnsTypes(t *testing.T) { kids: map[string][]model.BaseNode{ ".": { {Path: "dir", Type: "dir"}, - {Path: "file.txt", Type: "file"}, + {Path: "file.txt", Type: "file", ObjectOID: "abc123", SizeState: "known", SizeBytes: 42}, {Path: "link", Type: "symlink"}, }, }, @@ -501,6 +501,15 @@ func TestReaddirTypedReturnsTypes(t *testing.T) { if types["dir"] != "dir" || types["file.txt"] != "file" || types["link"] != "symlink" { t.Fatalf("wrong types: %v", types) } + for _, e := range entries { + if e.Name == "file.txt" { + if e.ObjectOID != "abc123" || e.SizeState != "known" || e.SizeBytes != 42 { + t.Fatalf("file metadata = %+v", e) + } + return + } + } + t.Fatal("file.txt entry not found") } func TestChildName(t *testing.T) { diff --git a/internal/fusefs/ops.go b/internal/fusefs/ops.go index f68af19..34cc31d 100644 --- a/internal/fusefs/ops.go +++ b/internal/fusefs/ops.go @@ -7,12 +7,15 @@ import ( "io/fs" "os" "path/filepath" + "sort" "time" "github.com/cloudflare/artifact-fs/internal/hydrator" "github.com/cloudflare/artifact-fs/internal/model" ) +const maxPrefetchTasksPerDir = 256 + type Engine struct { Resolver *Resolver Repo 
model.RepoConfig @@ -57,6 +60,26 @@ func (e *Engine) Read(ctx context.Context, path string, off int64, size int) ([] return readFileChunk(cachePath, off, size) } +func (e *Engine) BaseCachePath(ctx context.Context, path string) (string, int64, bool, error) { + path = model.CleanPath(path) + if ov, ok := e.Overlay.Get(path); ok { + if ov.IsDeleted() { + return "", 0, false, os.ErrNotExist + } + return "", 0, false, nil + } + gen := e.Resolver.Generation() + n, ok := e.Resolver.Snapshot.GetNode(gen, path) + if !ok { + return "", 0, false, fs.ErrNotExist + } + cachePath, _, err := e.Hydrator.EnsureHydrated(ctx, e.Repo, n) + if err != nil { + return "", 0, false, err + } + return cachePath, gen, true, nil +} + func (e *Engine) Write(ctx context.Context, path string, off int64, data []byte) (int, error) { if err := e.ensureOverlay(ctx, path); err != nil { if !errors.Is(err, fs.ErrNotExist) { @@ -178,26 +201,34 @@ func (e *Engine) Truncate(ctx context.Context, path string, size int64) error { // PrefetchDir enqueues file children of a directory for speculative hydration. // Called from OpenDir in a goroutine so it doesn't block the FUSE operation. 
func (e *Engine) PrefetchDir(dirPath string, entries []ReaddirEntry) { - gen := e.Resolver.Generation() + tasks := make([]model.HydrationTask, 0, len(entries)) for _, entry := range entries { - if entry.Type != "file" { + if entry.Type != "file" || entry.ObjectOID == "" { continue } childPath := model.CleanPath(filepath.Join(dirPath, entry.Name)) - n, ok := e.Resolver.Snapshot.GetNode(gen, childPath) - if !ok || n.ObjectOID == "" { - continue - } pri := hydrator.ClassifyPriority(childPath) - e.Hydrator.Enqueue(model.HydrationTask{ + tasks = append(tasks, model.HydrationTask{ RepoID: e.Repo.ID, Path: childPath, - ObjectOID: n.ObjectOID, + ObjectOID: entry.ObjectOID, + SizeState: entry.SizeState, + SizeBytes: entry.SizeBytes, Priority: pri, Reason: "prefetch", EnqueuedAt: time.Now(), }) } + if len(tasks) > maxPrefetchTasksPerDir { + sort.SliceStable(tasks, func(i, j int) bool { + if tasks[i].Priority == tasks[j].Priority { + return tasks[i].Path < tasks[j].Path + } + return tasks[i].Priority > tasks[j].Priority + }) + tasks = tasks[:maxPrefetchTasksPerDir] + } + e.Hydrator.EnqueueBatch(tasks) } func readFileChunk(path string, off int64, size int) ([]byte, error) { @@ -206,6 +237,10 @@ func readFileChunk(path string, off int64, size int) ([]byte, error) { return nil, err } defer f.Close() + return readFileChunkFrom(f, off, size) +} + +func readFileChunkFrom(f *os.File, off int64, size int) ([]byte, error) { if _, err := f.Seek(off, io.SeekStart); err != nil { return nil, err } diff --git a/internal/fusefs/ops_test.go b/internal/fusefs/ops_test.go new file mode 100644 index 0000000..734b658 --- /dev/null +++ b/internal/fusefs/ops_test.go @@ -0,0 +1,277 @@ +package fusefs + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/cloudflare/artifact-fs/internal/hydrator" + "github.com/cloudflare/artifact-fs/internal/model" +) + +type fakeBatchHydrator struct { + tasks []model.HydrationTask + calls int + path string + pathsByOID 
map[string]string +} + +type generationSnapshot struct { + nodes map[int64]map[string]model.BaseNode +} + +type blockingCopyOverlay struct { + *fakeOverlay + beforeCopy chan struct{} + continueCopy chan struct{} + backingPath string +} + +func (f *fakeBatchHydrator) Enqueue(task model.HydrationTask) { + f.tasks = append(f.tasks, task) +} + +func (f *fakeBatchHydrator) EnqueueBatch(tasks []model.HydrationTask) { + f.tasks = append(f.tasks, tasks...) +} + +func (f *fakeBatchHydrator) EnsureHydrated(_ context.Context, _ model.RepoConfig, node model.BaseNode) (string, int64, error) { + f.calls++ + if f.pathsByOID != nil { + return f.pathsByOID[node.ObjectOID], 0, nil + } + return f.path, 0, nil +} + +func (f *fakeBatchHydrator) ReadBlob(context.Context, model.RepoConfig, model.BaseNode, int64) ([]byte, error) { + return nil, nil +} + +func (f *fakeBatchHydrator) QueueDepth(model.RepoID) int { return len(f.tasks) } + +func (g *generationSnapshot) PublishGeneration(context.Context, string, string, []model.BaseNode) (int64, error) { + return 0, nil +} + +func (g *generationSnapshot) GetNode(gen int64, path string) (model.BaseNode, bool) { + n, ok := g.nodes[gen][path] + return n, ok +} + +func (g *generationSnapshot) ListChildren(int64, string) ([]model.BaseNode, error) { + return nil, nil +} + +func (o *blockingCopyOverlay) EnsureCopyOnWrite(_ context.Context, _ model.RepoConfig, path string, base model.BaseNode) (model.OverlayEntry, error) { + close(o.beforeCopy) + <-o.continueCopy + now := time.Now().UnixNano() + e := model.OverlayEntry{Path: model.CleanPath(path), Kind: model.OverlayKindModify, BackingPath: o.backingPath, Mode: base.Mode, MtimeUnixNs: now, CtimeUnixNs: now, SourceOID: base.ObjectOID} + o.entries[e.Path] = e + return e, nil +} + +func (o *blockingCopyOverlay) WriteFile(_ context.Context, path string, off int64, data []byte) (int, error) { + f, err := os.OpenFile(o.backingPath, os.O_CREATE|os.O_RDWR, 0o644) + if err != nil { + return 0, err + } + defer 
f.Close() + if _, err := f.WriteAt(data, off); err != nil { + return 0, err + } + e := o.entries[model.CleanPath(path)] + if info, err := f.Stat(); err == nil { + e.SizeBytes = info.Size() + } + o.entries[model.CleanPath(path)] = e + return len(data), nil +} + +func TestPrefetchDirBatchesReaddirMetadata(t *testing.T) { + hydrator := &fakeBatchHydrator{} + engine := &Engine{Repo: model.RepoConfig{ID: "repo"}, Hydrator: hydrator} + + engine.PrefetchDir("src", []ReaddirEntry{ + {Name: "a.go", Type: "file", ObjectOID: "a", SizeState: "known", SizeBytes: 10}, + {Name: "sub", Type: "dir"}, + {Name: "overlay.txt", Type: "file"}, + {Name: "b.go", Type: "file", ObjectOID: "b", SizeState: "unknown"}, + }) + + if len(hydrator.tasks) != 2 { + t.Fatalf("tasks = %d, want 2", len(hydrator.tasks)) + } + if hydrator.tasks[0].Path != "src/a.go" || hydrator.tasks[0].ObjectOID != "a" || hydrator.tasks[0].SizeBytes != 10 { + t.Fatalf("first task = %+v", hydrator.tasks[0]) + } + if hydrator.tasks[1].Path != "src/b.go" || hydrator.tasks[1].ObjectOID != "b" || hydrator.tasks[1].SizeState != "unknown" { + t.Fatalf("second task = %+v", hydrator.tasks[1]) + } +} + +func TestPrefetchDirCapsAndPrioritizesTasks(t *testing.T) { + h := &fakeBatchHydrator{} + engine := &Engine{Repo: model.RepoConfig{ID: "repo"}, Hydrator: h} + entries := make([]ReaddirEntry, 0, maxPrefetchTasksPerDir+2) + for i := 0; i < maxPrefetchTasksPerDir+1; i++ { + entries = append(entries, ReaddirEntry{Name: fmt.Sprintf("image-%03d.png", i), Type: "file", ObjectOID: fmt.Sprintf("png-%03d", i)}) + } + entries = append(entries, ReaddirEntry{Name: "README.md", Type: "file", ObjectOID: "readme"}) + + engine.PrefetchDir(".", entries) + + if len(h.tasks) != maxPrefetchTasksPerDir { + t.Fatalf("tasks = %d, want %d", len(h.tasks), maxPrefetchTasksPerDir) + } + foundReadme := false + for _, task := range h.tasks { + if task.ObjectOID == "readme" { + foundReadme = true + if task.Priority < hydrator.PriorityBootstrap { + 
t.Fatalf("README priority = %d", task.Priority) + } + } + } + if !foundReadme { + t.Fatal("README.md was dropped from capped prefetch") + } +} + +func TestFileHandleCachesHydratedBaseFile(t *testing.T) { + tmp := t.TempDir() + cachePath := filepath.Join(tmp, "blob") + if err := os.WriteFile(cachePath, []byte("content"), 0o644); err != nil { + t.Fatal(err) + } + h := &fakeBatchHydrator{path: cachePath} + overlay := &fakeOverlay{entries: map[string]model.OverlayEntry{}} + engine := &Engine{ + Repo: model.RepoConfig{ID: "repo"}, + Resolver: newResolver(&fakeSnapshot{nodes: map[string]model.BaseNode{"file.txt": {Path: "file.txt", Type: "file", ObjectOID: "blob"}}}, overlay), + Overlay: overlay, + Hydrator: h, + } + fh := &FileHandle{path: "file.txt"} + defer fh.closeCachedFile() + + first, err := fh.read(context.Background(), engine, 0, 4) + if err != nil { + t.Fatal(err) + } + second, err := fh.read(context.Background(), engine, 4, 3) + if err != nil { + t.Fatal(err) + } + if string(first) != "cont" || string(second) != "ent" { + t.Fatalf("reads = %q/%q", first, second) + } + if h.calls != 1 { + t.Fatalf("EnsureHydrated calls = %d, want 1", h.calls) + } +} + +func TestFileHandleRehydratesAfterGenerationChange(t *testing.T) { + tmp := t.TempDir() + firstCache := filepath.Join(tmp, "first") + secondCache := filepath.Join(tmp, "second") + if err := os.WriteFile(firstCache, []byte("old"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(secondCache, []byte("new"), 0o644); err != nil { + t.Fatal(err) + } + h := &fakeBatchHydrator{pathsByOID: map[string]string{"first": firstCache, "second": secondCache}} + overlay := &fakeOverlay{entries: map[string]model.OverlayEntry{}} + resolver := &Resolver{ + Snapshot: &generationSnapshot{nodes: map[int64]map[string]model.BaseNode{ + 1: {"file.txt": {Path: "file.txt", Type: "file", ObjectOID: "first"}}, + 2: {"file.txt": {Path: "file.txt", Type: "file", ObjectOID: "second"}}, + }}, + Overlay: overlay, + } + 
resolver.SetGeneration(1) + engine := &Engine{Repo: model.RepoConfig{ID: "repo"}, Resolver: resolver, Overlay: overlay, Hydrator: h} + fh := &FileHandle{path: "file.txt"} + defer fh.closeCachedFile() + + first, err := fh.read(context.Background(), engine, 0, 3) + if err != nil { + t.Fatal(err) + } + resolver.SetGeneration(2) + second, err := fh.read(context.Background(), engine, 0, 3) + if err != nil { + t.Fatal(err) + } + if string(first) != "old" || string(second) != "new" { + t.Fatalf("reads = %q/%q", first, second) + } + if h.calls != 2 { + t.Fatalf("EnsureHydrated calls = %d, want 2", h.calls) + } +} + +func TestFileHandleInvalidatesCacheAfterOverlappingWrite(t *testing.T) { + tmp := t.TempDir() + baseCache := filepath.Join(tmp, "base") + overlayBacking := filepath.Join(tmp, "overlay") + if err := os.WriteFile(baseCache, []byte("old"), 0o644); err != nil { + t.Fatal(err) + } + h := &fakeBatchHydrator{path: baseCache} + overlay := &blockingCopyOverlay{ + fakeOverlay: &fakeOverlay{entries: map[string]model.OverlayEntry{}}, + beforeCopy: make(chan struct{}), + continueCopy: make(chan struct{}), + backingPath: overlayBacking, + } + resolver := newResolver(&fakeSnapshot{nodes: map[string]model.BaseNode{"file.txt": {Path: "file.txt", Type: "file", ObjectOID: "blob", Mode: 0o644}}}, overlay.fakeOverlay) + engine := &Engine{Repo: model.RepoConfig{ID: "repo"}, Resolver: resolver, Overlay: overlay, Hydrator: h} + fs := NewArtifactFuse(model.RepoConfig{ID: "repo"}, resolver, engine) + fh := &FileHandle{path: "file.txt"} + fs.fileHandles[1] = fh + defer fh.closeCachedFile() + + first, err := fh.read(context.Background(), engine, 0, 3) + if err != nil { + t.Fatal(err) + } + if string(first) != "old" { + t.Fatalf("first read = %q", first) + } + + writeDone := make(chan error, 1) + go func() { + fs.closeCachedFilesForPath("file.txt") + _, err := engine.Write(context.Background(), "file.txt", 0, []byte("new")) + if err == nil { + fs.closeCachedFilesForPath("file.txt") + } + 
writeDone <- err + }() + + <-overlay.beforeCopy + overlapped, err := fh.read(context.Background(), engine, 0, 3) + if err != nil { + t.Fatal(err) + } + if string(overlapped) != "old" { + t.Fatalf("overlapped read = %q", overlapped) + } + close(overlay.continueCopy) + if err := <-writeDone; err != nil { + t.Fatal(err) + } + + after, err := fh.read(context.Background(), engine, 0, 3) + if err != nil { + t.Fatal(err) + } + if string(after) != "new" { + t.Fatalf("read after write = %q, want new", after) + } +} diff --git a/internal/fusefs/readsymlink_unix_test.go b/internal/fusefs/readsymlink_unix_test.go index 23e9580..453697b 100644 --- a/internal/fusefs/readsymlink_unix_test.go +++ b/internal/fusefs/readsymlink_unix_test.go @@ -24,6 +24,8 @@ type fakeSymlinkHydrator struct { func (f *fakeSymlinkHydrator) Enqueue(model.HydrationTask) {} +func (f *fakeSymlinkHydrator) EnqueueBatch([]model.HydrationTask) {} + func (f *fakeSymlinkHydrator) EnsureHydrated(_ context.Context, _ model.RepoConfig, _ model.BaseNode) (string, int64, error) { f.calls++ return f.cachePath, f.size, f.err diff --git a/internal/fusefs/ready_gate.go b/internal/fusefs/ready_gate.go new file mode 100644 index 0000000..229fada --- /dev/null +++ b/internal/fusefs/ready_gate.go @@ -0,0 +1,115 @@ +package fusefs + +import ( + "context" + "errors" + "sync" +) + +var ErrRepoNotReady = errors.New("repo is not ready") + +type ReadyGate struct { + mu sync.Mutex + ready bool + err error + done chan struct{} + closed bool +} + +func NewReadyGate(ready bool) *ReadyGate { + g := &ReadyGate{ready: ready, done: make(chan struct{})} + if ready { + close(g.done) + g.closed = true + } + return g +} + +func (g *ReadyGate) Wait(ctx context.Context) error { + if g == nil { + return nil + } + g.mu.Lock() + if g.ready { + g.mu.Unlock() + return nil + } + if g.err != nil { + err := g.err + g.mu.Unlock() + return err + } + done := g.done + g.mu.Unlock() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + } + 
+ g.mu.Lock() + defer g.mu.Unlock() + if g.ready { + return nil + } + if g.err != nil { + return g.err + } + return ErrRepoNotReady +} + +func (g *ReadyGate) Reset() { + if g == nil { + return + } + g.mu.Lock() + defer g.mu.Unlock() + if !g.ready && g.err == nil { + return + } + g.ready = false + g.err = nil + g.done = make(chan struct{}) + g.closed = false +} + +func (g *ReadyGate) MarkReady() { + if g == nil { + return + } + g.mu.Lock() + defer g.mu.Unlock() + if g.ready { + return + } + g.ready = true + g.err = nil + if !g.closed { + close(g.done) + g.closed = true + } +} + +func (g *ReadyGate) MarkFailed(err error) { + if g == nil { + return + } + if err == nil { + err = ErrRepoNotReady + } + g.mu.Lock() + defer g.mu.Unlock() + if g.err != nil { + g.err = err + return + } + if g.ready { + return + } + g.err = err + if !g.closed { + close(g.done) + g.closed = true + } +} diff --git a/internal/fusefs/ready_gate_test.go b/internal/fusefs/ready_gate_test.go new file mode 100644 index 0000000..4800cdc --- /dev/null +++ b/internal/fusefs/ready_gate_test.go @@ -0,0 +1,51 @@ +package fusefs + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestReadyGateWaitsUntilReady(t *testing.T) { + g := NewReadyGate(false) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- g.Wait(ctx) + }() + + select { + case err := <-done: + t.Fatalf("Wait returned before ready: %v", err) + case <-time.After(20 * time.Millisecond): + } + + g.MarkReady() + if err := <-done; err != nil { + t.Fatalf("Wait returned error after ready: %v", err) + } +} + +func TestReadyGateFailureFailsFastAndCanReset(t *testing.T) { + g := NewReadyGate(false) + g.MarkFailed(errors.New("clone failed")) + + if err := g.Wait(context.Background()); err == nil || err.Error() != "clone failed" { + t.Fatalf("Wait error = %v, want clone failed", err) + } + + g.Reset() + g.MarkReady() + if err := 
g.Wait(context.Background()); err != nil { + t.Fatalf("Wait after reset/ready: %v", err) + } +} + +func TestReadyGateMarkReadyAfterFailureDoesNotPanic(t *testing.T) { + g := NewReadyGate(false) + g.MarkFailed(errors.New("clone failed")) + g.MarkReady() +} diff --git a/internal/gitstore/gitstore.go b/internal/gitstore/gitstore.go index 5c4418f..dc1683b 100644 --- a/internal/gitstore/gitstore.go +++ b/internal/gitstore/gitstore.go @@ -32,8 +32,23 @@ type readBlobResult struct { err error } +type fetchBlobResult struct { + size int64 + err error +} + const maxReadBlobBytes int64 = 1<<31 - 1 +const fetchedFullRefRemoteTrackingRef = "refs/remotes/artifact-fs/fetch-ref" + +const zeroOID = "0000000000000000000000000000000000000000" + +type fetchRefInfo struct { + sourceRef string + remoteRef string + branch string +} + func New(logger *slog.Logger) *Store { if logger == nil { logger = slog.Default() @@ -64,6 +79,14 @@ func (s *Store) SetBatchPoolSize(n int) { } func (s *Store) CloneBlobless(ctx context.Context, cfg model.RepoConfig) error { + return s.cloneBlobless(ctx, cfg, nil) +} + +func (s *Store) CloneBloblessNonInteractive(ctx context.Context, cfg model.RepoConfig) error { + return s.cloneBlobless(ctx, cfg, nonInteractiveGitEnv()) +} + +func (s *Store) cloneBlobless(ctx context.Context, cfg model.RepoConfig, extraEnv []string) error { if _, err := os.Stat(cfg.GitDir); err == nil { return nil } @@ -80,27 +103,151 @@ func (s *Store) CloneBlobless(ctx context.Context, cfg model.RepoConfig) error { // Strip credentials from the CLI-visible URL; pass them via a credential helper // so they don't appear in ps output. 
- safeURL, credHelper := credentialEnv(cfg.RemoteURL) - - args := []string{"clone", "--filter=blob:none", "--no-checkout", "--single-branch", "--branch", cfg.Branch, safeURL, target} - if _, err := runGitWithEnv(ctx, "", credHelper, args...); err != nil { + safeURL, credHelper, err := credentialEnv(cfg.RemoteURL) + if err != nil { return err } - if err := os.Rename(filepath.Join(target, ".git"), cfg.GitDir); err != nil { + env := append([]string{}, extraEnv...) + env = append(env, credHelper...) + + args := []string{"clone", "--filter=blob:none", "--no-checkout", "--single-branch", "--no-tags", "--branch", cfg.Branch, safeURL, target} + if _, err := runGitWithEnv(ctx, "", env, args...); err != nil { return err } // Populate the index so git status works inside the mount. - if _, err := runGit(ctx, cfg.GitDir, "read-tree", "HEAD"); err != nil { + if _, err := runGit(ctx, filepath.Join(target, ".git"), "read-tree", "HEAD"); err != nil { + return err + } + if err := os.Rename(filepath.Join(target, ".git"), cfg.GitDir); err != nil { return err } return nil } func (s *Store) Fetch(ctx context.Context, repo model.RepoConfig) error { - _, err := runGit(ctx, repo.GitDir, "fetch", "origin") + _, err := runGit(ctx, repo.GitDir, "fetch", "--no-tags", "origin") + return err +} + +func (s *Store) FetchRefNonInteractive(ctx context.Context, repo model.RepoConfig, ref string) error { + target, err := fetchRefTarget(repo, ref) + if err != nil { + return err + } + refspec := "+" + target.sourceRef + ":" + target.remoteRef + if target.branch != "" { + refspec = fmt.Sprintf("+refs/heads/%s:refs/remotes/origin/%s", target.branch, target.branch) + } + _, err = runGitWithEnv(ctx, repo.GitDir, nonInteractiveGitEnv(), "fetch", "--filter=blob:none", "--no-tags", "origin", refspec) return err } +func (s *Store) PrepareExistingCloneNonInteractive(ctx context.Context, repo model.RepoConfig) error { + if err := s.ValidateAmbientRemote(repo); err != nil { + return err + } + remoteURL, err := 
runGit(ctx, repo.GitDir, "remote", "get-url", "origin") + if err != nil { + return err + } + if strings.TrimSpace(remoteURL) != strings.TrimSpace(repo.RemoteURL) { + if _, err := runGitWithEnv(ctx, repo.GitDir, nonInteractiveGitEnv(), "remote", "set-url", "origin", repo.RemoteURL); err != nil { + return err + } + } + if err := s.FetchRefNonInteractive(ctx, repo, repo.FetchRef); err != nil { + return err + } + return s.PrepareFetchedBranch(ctx, repo, repo.FetchRef) +} + +func (s *Store) ValidateAmbientRemote(repo model.RepoConfig) error { + if strings.TrimSpace(repo.RemoteURL) == "" { + return errors.New("remote URL is required") + } + safeURL, _, err := credentialEnv(repo.RemoteURL) + if err != nil { + return err + } + if safeURL != repo.RemoteURL { + return errors.New("remote must use ambient credentials") + } + return nil +} + +func (s *Store) PrepareFetchedBranch(ctx context.Context, repo model.RepoConfig, ref string) error { + target, err := fetchRefTarget(repo, ref) + if err != nil { + return err + } + oid, err := runGit(ctx, repo.GitDir, "rev-parse", "--verify", target.remoteRef+"^{commit}") + if err != nil { + return fmt.Errorf("remote ref %s missing after fetch: %w", target.remoteRef, err) + } + oid = strings.TrimSpace(oid) + if target.branch == "" { + if _, err := runGit(ctx, repo.GitDir, "update-ref", "--no-deref", "HEAD", oid); err != nil { + return err + } + return s.ReadTreeHEAD(ctx, repo) + } + refName := "refs/heads/" + target.branch + if repo.PreparedGitDir { + oldOID, err := s.preparedBranchExpectedOID(ctx, repo, target.branch, oid) + if err != nil { + return err + } + if _, err := runGit(ctx, repo.GitDir, "update-ref", refName, oid, oldOID); err != nil { + return err + } + } else if _, err := runGit(ctx, repo.GitDir, "update-ref", refName, oid); err != nil { + return err + } + if _, err := runGit(ctx, repo.GitDir, "symbolic-ref", "HEAD", "refs/heads/"+target.branch); err != nil { + return err + } + if _, err := runGit(ctx, repo.GitDir, "branch", 
"--set-upstream-to", "origin/"+target.branch, target.branch); err != nil { + s.logger.Warn("set upstream failed", "repo", repo.Name, "error", err) + } + return s.ReadTreeHEAD(ctx, repo) +} + +func (s *Store) preparedBranchExpectedOID(ctx context.Context, repo model.RepoConfig, branch string, oid string) (string, error) { + current, err := runGit(ctx, repo.GitDir, "rev-parse", "--verify", "refs/heads/"+branch+"^{commit}") + if err != nil { + return zeroOID, nil + } + current = strings.TrimSpace(current) + if current == oid { + return current, nil + } + if _, err := runGit(ctx, repo.GitDir, "merge-base", "--is-ancestor", current, oid); err != nil { + return "", fmt.Errorf("prepared git dir branch %s would be overwritten; refusing non-fast-forward update", branch) + } + return current, nil +} + +func (s *Store) ValidatePreparedGitDir(ctx context.Context, repo model.RepoConfig) error { + if strings.TrimSpace(repo.GitDir) == "" { + return errors.New("git dir is required") + } + st, err := os.Stat(repo.GitDir) + if err != nil { + return err + } + if !st.IsDir() { + return fmt.Errorf("git dir %s is not a directory", repo.GitDir) + } + if _, err := runGit(ctx, repo.GitDir, "rev-parse", "--git-dir"); err != nil { + return err + } + remoteURL, err := runGit(ctx, repo.GitDir, "remote", "get-url", "origin") + if err == nil && remoteHasInlineCredentials(remoteURL) { + return errors.New("prepared git dir origin must use ambient credentials") + } + return nil +} + func (s *Store) ResolveHEAD(ctx context.Context, repo model.RepoConfig) (oid string, ref string, err error) { oid, err = runGit(ctx, repo.GitDir, "rev-parse", "HEAD") if err != nil { @@ -116,55 +263,24 @@ func (s *Store) ResolveHEAD(ctx context.Context, repo model.RepoConfig) (oid str func (s *Store) BuildTreeIndex(ctx context.Context, repo model.RepoConfig, headOID string) ([]model.BaseNode, error) { // -z: NUL-delimited output with raw paths (no C-quoting of non-ASCII names). 
- out, err := runGit(ctx, repo.GitDir, "ls-tree", "-r", "-t", "-z", headOID) - if err != nil { - return nil, err - } - records := strings.Split(out, "\x00") nodes := []model.BaseNode{rootNode(repo.ID)} var blobOIDs []string blobIndex := map[string][]int{} // oid -> indices into nodes - for _, line := range records { - if line == "" { - continue - } - parts := strings.SplitN(line, "\t", 2) - if len(parts) != 2 { - continue - } - meta := strings.Fields(parts[0]) - if len(meta) < 3 { - continue - } - modeStr := meta[0] - typ := meta[1] - oid := meta[2] - path := parts[1] - mode64, _ := strconv.ParseUint(modeStr, 8, 32) - mode := uint32(mode64) - - nodeType := normalizeGitType(typ, mode) - if typ == "commit" { - continue - } - - n := model.BaseNode{ - RepoID: repo.ID, - Path: path, - Type: nodeType, - Mode: mode, - ObjectOID: oid, - SizeState: "unknown", - SizeBytes: 0, + if err := streamTreeRecords(ctx, repo.GitDir, headOID, func(line string) { + n, typ, ok := parseTreeRecord(repo.ID, line) + if !ok { + return } idx := len(nodes) nodes = append(nodes, n) - if typ == "blob" && oid != "" { - blobIndex[oid] = append(blobIndex[oid], idx) - if len(blobIndex[oid]) == 1 { - blobOIDs = append(blobOIDs, oid) + if typ == "blob" && n.ObjectOID != "" { + blobIndex[n.ObjectOID] = append(blobIndex[n.ObjectOID], idx) + if len(blobIndex[n.ObjectOID]) == 1 { + blobOIDs = append(blobOIDs, n.ObjectOID) } } + }); err != nil { + return nil, err } // Batch-resolve sizes using cat-file --batch-check. 
This reads from local @@ -177,6 +293,80 @@ func (s *Store) BuildTreeIndex(ctx context.Context, repo model.RepoConfig, headO return addImplicitDirs(repo.ID, nodes), nil } +func streamTreeRecords(ctx context.Context, gitDir string, headOID string, fn func(string)) error { + cmd := exec.CommandContext(ctx, "git", "ls-tree", "-r", "-t", "-z", headOID) + cmd.Env = append(os.Environ(), "GIT_DIR="+gitDir) + stdout, err := cmd.StdoutPipe() + if err != nil { + return err + } + errBuf := &bytes.Buffer{} + cmd.Stderr = errBuf + if err := cmd.Start(); err != nil { + return err + } + readErr := readNullDelimited(stdout, fn) + waitErr := cmd.Wait() + if readErr != nil { + return readErr + } + if waitErr != nil { + msg := auth.RedactString(strings.TrimSpace(errBuf.String())) + if msg == "" { + msg = auth.RedactString(waitErr.Error()) + } + return errors.New(msg) + } + return nil +} + +func readNullDelimited(r io.Reader, fn func(string)) error { + reader := bufio.NewReader(r) + for { + record, err := reader.ReadString('\x00') + if record != "" { + record = strings.TrimSuffix(record, "\x00") + if record != "" { + fn(record) + } + } + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return err + } + } +} + +func parseTreeRecord(repoID model.RepoID, line string) (model.BaseNode, string, bool) { + parts := strings.SplitN(line, "\t", 2) + if len(parts) != 2 { + return model.BaseNode{}, "", false + } + meta := strings.Fields(parts[0]) + if len(meta) < 3 { + return model.BaseNode{}, "", false + } + modeStr := meta[0] + typ := meta[1] + oid := meta[2] + mode64, _ := strconv.ParseUint(modeStr, 8, 32) + mode := uint32(mode64) + if typ == "commit" { + return model.BaseNode{}, typ, false + } + return model.BaseNode{ + RepoID: repoID, + Path: parts[1], + Type: normalizeGitType(typ, mode), + Mode: mode, + ObjectOID: oid, + SizeState: "unknown", + SizeBytes: 0, + }, typ, true +} + func (s *Store) batchResolveSizes(ctx context.Context, repo model.RepoConfig, nodes 
[]model.BaseNode, oids []string, index map[string][]int) error { if len(oids) == 0 { return nil @@ -192,44 +382,76 @@ func (s *Store) batchResolveSizes(ctx context.Context, repo model.RepoConfig, no if err != nil { return err } - var outBuf bytes.Buffer - cmd.Stdout = &outBuf - cmd.Stderr = &bytes.Buffer{} - if err := cmd.Start(); err != nil { + stdout, err := cmd.StdoutPipe() + if err != nil { return err } - for _, oid := range oids { - fmt.Fprintln(stdin, oid) - } - stdin.Close() - if err := cmd.Wait(); err != nil { + errBuf := &bytes.Buffer{} + cmd.Stderr = errBuf + if err := cmd.Start(); err != nil { return err } - // Output format: " " or " missing" - scan := bufio.NewScanner(&outBuf) - for scan.Scan() { - fields := strings.Fields(scan.Text()) - if len(fields) < 3 { - continue + writeErrCh := make(chan error, 1) + go func() { + var writeErr error + for _, oid := range oids { + if _, writeErr = fmt.Fprintln(stdin, oid); writeErr != nil { + break + } } - oid := fields[0] - sizeStr := fields[2] - sz, err := strconv.ParseInt(sizeStr, 10, 64) - if err != nil { - continue + if closeErr := stdin.Close(); writeErr == nil { + writeErr = closeErr } - for _, idx := range index[oid] { - nodes[idx].SizeBytes = sz - nodes[idx].SizeState = "known" + writeErrCh <- writeErr + }() + // Output format: " " or " missing" + scan := bufio.NewScanner(stdout) + for scan.Scan() { + applyBatchCheckLine(nodes, index, scan.Text()) + } + scanErr := scan.Err() + writeErr := <-writeErrCh + waitErr := cmd.Wait() + if writeErr != nil { + return writeErr + } + if scanErr != nil { + return scanErr + } + if waitErr != nil { + msg := auth.RedactString(strings.TrimSpace(errBuf.String())) + if msg == "" { + msg = auth.RedactString(waitErr.Error()) } + return errors.New(msg) + } + return nil +} + +func applyBatchCheckLine(nodes []model.BaseNode, index map[string][]int, line string) { + fields := strings.Fields(line) + if len(fields) < 3 { + return + } + oid := fields[0] + sizeStr := fields[2] + sz, 
err := strconv.ParseInt(sizeStr, 10, 64) + if err != nil { + return + } + for _, idx := range index[oid] { + nodes[idx].SizeBytes = sz + nodes[idx].SizeState = "known" } - return scan.Err() } // BlobToCache fetches a git object and writes it to dstPath in a binary-safe manner. // Uses a persistent cat-file --batch process to amortize process spawn and // remote connection costs across multiple blob fetches. func (s *Store) BlobToCache(ctx context.Context, repo model.RepoConfig, objectOID string, dstPath string) (size int64, err error) { + if err := ctx.Err(); err != nil { + return 0, err + } if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil { return 0, err } @@ -238,16 +460,22 @@ func (s *Store) BlobToCache(ctx context.Context, repo model.RepoConfig, objectOI if err != nil { return 0, err } - size, err = batch.fetchToFile(objectOID, dstPath) + size, err = fetchBatchToFile(ctx, batch, objectOID, dstPath) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return 0, err + } // Process may have died or be desynchronized; discard and retry. batch.close() batch, err = pool.acquire() if err != nil { return 0, err } - size, err = batch.fetchToFile(objectOID, dstPath) + size, err = fetchBatchToFile(ctx, batch, objectOID, dstPath) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return 0, err + } // Retry also failed; close instead of returning a potentially // corrupted process to the pool. 
batch.close() @@ -258,6 +486,22 @@ func (s *Store) BlobToCache(ctx context.Context, repo model.RepoConfig, objectOI return size, err } +func fetchBatchToFile(ctx context.Context, batch *batchCatFile, objectOID string, dstPath string) (int64, error) { + ch := make(chan fetchBlobResult, 1) + go func() { + size, err := batch.fetchToFile(objectOID, dstPath) + ch <- fetchBlobResult{size: size, err: err} + }() + select { + case r := <-ch: + return r.size, r.err + case <-ctx.Done(): + batch.kill() + <-ch + return 0, ctx.Err() + } +} + func (s *Store) ReadBlob(ctx context.Context, repo model.RepoConfig, objectOID string, maxBytes int64) ([]byte, error) { if maxBytes < 0 { return nil, fmt.Errorf("negative max bytes: %d", maxBytes) @@ -398,16 +642,18 @@ func (p *batchPool) setMaxSize(n int) { // remote connection costs across multiple blob fetches. Callers must ensure // exclusive access (the batchPool handles this). type batchCatFile struct { - cmd *exec.Cmd - stdin io.WriteCloser - stdout *bufio.Reader - logger *slog.Logger + cmd *exec.Cmd + stdin io.WriteCloser + stdoutPipe io.ReadCloser + stdout *bufio.Reader + logger *slog.Logger } func newBatchCatFile(gitDir string, logger *slog.Logger) (*batchCatFile, error) { cmd := exec.Command("git", "cat-file", "--batch") cmd.Env = append(os.Environ(), "GIT_DIR="+gitDir) cmd.Stderr = os.Stderr + configureBatchCommand(cmd) stdin, err := cmd.StdinPipe() if err != nil { @@ -421,10 +667,11 @@ func newBatchCatFile(gitDir string, logger *slog.Logger) (*batchCatFile, error) return nil, fmt.Errorf("batch cat-file start: %w", err) } return &batchCatFile{ - cmd: cmd, - stdin: stdin, - stdout: bufio.NewReaderSize(stdout, 256*1024), - logger: logger, + cmd: cmd, + stdin: stdin, + stdoutPipe: stdout, + stdout: bufio.NewReaderSize(stdout, 256*1024), + logger: logger, }, nil } @@ -442,9 +689,10 @@ func (b *batchCatFile) close() { } func (b *batchCatFile) kill() { - if b.cmd != nil && b.cmd.Process != nil { - _ = b.cmd.Process.Kill() + if 
b.stdoutPipe != nil { + _ = b.stdoutPipe.Close() } + killBatchCommand(b.cmd) b.close() } @@ -626,41 +874,366 @@ func runGitWithEnv(ctx context.Context, gitDir string, extraEnv []string, args . // credentialEnv returns a sanitized URL (safe for ps) and env vars that // configure a one-shot git credential helper to supply the real credentials. -func credentialEnv(rawURL string) (safeURL string, env []string) { +func credentialEnv(rawURL string) (safeURL string, env []string, err error) { if rawURL == "" { - return "", nil + return "", nil, nil + } + if strings.ContainsAny(rawURL, "?#") { + return "", nil, errors.New("remote URL must not include query or fragment") } u, err := url.Parse(rawURL) - if err != nil || u.User == nil { - return rawURL, nil + if err != nil { + if remoteHasInlineCredentials(rawURL) { + return "", nil, errors.New("malformed remote URL") + } + if rawUserinfoCandidateHasPassword(rawURL) { + return "", nil, errors.New("malformed remote URL") + } + if strings.Contains(rawURL, "://") { + return "", nil, errors.New("malformed remote URL") + } + return rawURL, nil, nil + } + if u.RawQuery != "" || u.ForceQuery || u.Fragment != "" || strings.Contains(rawURL, "#") { + return "", nil, errors.New("remote URL must not include query or fragment") + } + if u.User == nil && strings.Contains(rawURL, "@") && (auth.HasInlineCredentials(rawURL) || malformedUserinfoInRemote(rawURL, u)) { + return "", nil, errors.New("malformed remote URL") + } + if u.User == nil { + return rawURL, nil, nil + } + if !isHTTPRemote(rawURL, u.Scheme) { + if strings.ToLower(u.Scheme) != "ssh" { + return "", nil, errors.New("remote URL includes unsupported inline credentials") + } + if _, hasPassword := u.User.Password(); hasPassword || auth.HasInlineCredentials(rawURL) { + return "", nil, errors.New("remote URL includes unsupported inline credentials") + } + return rawURL, nil, nil } username := u.User.Username() password, hasPassword := u.User.Password() if username == "" && 
!hasPassword { - return rawURL, nil + return rawURL, nil, nil } - // Build a credential helper that prints credentials to stdout. - // Uses printf to avoid shell quoting issues with single quotes in passwords. - var lines []string + credentialUsername := username + credentialPassword := password if hasPassword { - lines = append(lines, "username="+username, "password="+password) + credentialPassword = password } else if username != "" { // Token-as-username pattern (e.g., https://ghp_xxx@github.com) - lines = append(lines, "username="+username, "password="+username) + credentialPassword = username } - // Escape single quotes in the credential payload to prevent shell injection. - payload := strings.Join(lines, "\n") - payload = strings.ReplaceAll(payload, "'", "'\\''") - helper := fmt.Sprintf("!f() { printf '%%s\\n' '%s'; }; f", payload) + helper := "!f() { printf '%s\\n' \"username=$ARTIFACT_FS_GIT_USERNAME\" \"password=$ARTIFACT_FS_GIT_PASSWORD\"; }; f" u.User = nil return u.String(), []string{ "GIT_TERMINAL_PROMPT=0", - "GIT_CONFIG_COUNT=1", + "ARTIFACT_FS_GIT_USERNAME=" + credentialUsername, + "ARTIFACT_FS_GIT_PASSWORD=" + credentialPassword, + "GIT_CONFIG_COUNT=2", "GIT_CONFIG_KEY_0=credential.helper", - "GIT_CONFIG_VALUE_0=" + helper, + "GIT_CONFIG_VALUE_0=", + "GIT_CONFIG_KEY_1=credential.helper", + "GIT_CONFIG_VALUE_1=" + helper, + }, nil +} + +func isHTTPRemote(rawURL string, scheme string) bool { + switch strings.ToLower(scheme) { + case "http", "https": + return true + } + lower := strings.ToLower(strings.TrimSpace(rawURL)) + return strings.HasPrefix(lower, "http:/") || strings.HasPrefix(lower, "https:/") || + strings.HasPrefix(lower, "http//") || strings.HasPrefix(lower, "https//") || + strings.HasPrefix(lower, "http:") || strings.HasPrefix(lower, "https:") +} + +func isMalformedHTTPUserinfo(rawURL string, u *url.URL) bool { + if !isHTTPRemote(rawURL, u.Scheme) { + return false + } + if u.Host == "" { + return true + } + return strings.HasPrefix(u.Path, 
"/@") +} + +func remoteHasInlineCredentials(rawURL string) bool { + if strings.ContainsAny(rawURL, "?#") { + return true + } + if schemeLessUserinfoHasPassword(rawURL) { + return true + } + u, err := url.Parse(rawURL) + if err != nil { + return auth.HasInlineCredentials(rawURL) || rawUserinfoCandidateHasPassword(rawURL) + } + if u.User != nil { + _, hasPassword := u.User.Password() + return isHTTPRemote(rawURL, u.Scheme) || strings.ToLower(u.Scheme) != "ssh" || hasPassword || auth.HasInlineCredentials(rawURL) + } + return strings.Contains(rawURL, "@") && malformedUserinfoInRemote(rawURL, u) +} + +func malformedUserinfoInRemote(rawURL string, u *url.URL) bool { + if isHTTPRemote(rawURL, u.Scheme) { + return auth.HasInlineCredentials(rawURL) + } + if isMalformedHTTPUserinfo(rawURL, u) { + return true + } + if isHTTPRemote(rawURL, u.Scheme) && strings.Contains(u.Hostname(), ".") { + return false + } + return rawUserinfoCandidateHasPassword(rawURL) +} + +func rawUserinfoCandidateHasPassword(raw string) bool { + if isSCPStyleRemote(raw) { + return false + } + if schemeLessUserinfoHasPassword(raw) { + return true + } + prefix := raw + start := -1 + if i := strings.LastIndex(prefix, "://"); i >= 0 { + start = i + len("://") + } else if i := strings.Index(prefix, ":/"); i >= 0 { + start = i + len(":/") + } else if i := strings.Index(prefix, "//"); i >= 0 { + start = i + len("//") + } else if i := strings.Index(prefix, ":"); i >= 0 { + start = i + len(":") + } + if start < 0 || start >= len(raw) { + return false + } + endChars := "?#" + if strings.Contains(raw, "://") { + endChars = "/?#" + } + end := len(raw) + if relEnd := strings.IndexAny(raw[start:], endChars); relEnd >= 0 { + end = start + relEnd + } + at := strings.LastIndex(raw[start:end], "@") + if at < 0 { + return false + } + at += start + return strings.Contains(raw[start:at], ":") +} + +func schemeLessUserinfoHasPassword(raw string) bool { + if strings.Contains(raw, "://") { + return false + } + if 
isSCPStyleRemote(raw) { + return false + } + end := len(raw) + if relEnd := strings.IndexAny(raw, "/?#"); relEnd >= 0 { + end = relEnd + } + if end == 0 { + return false + } + prefix := raw[:end] + at := strings.LastIndex(prefix, "@") + colon := strings.Index(prefix, ":") + return colon >= 0 && (at > colon || strings.Contains(raw[end:], "@")) +} + +func isSCPStyleRemote(raw string) bool { + if strings.Contains(raw, "://") { + return false + } + end := len(raw) + if relEnd := strings.IndexAny(raw, "/?#"); relEnd >= 0 { + end = relEnd + } + prefix := raw[:end] + at := strings.Index(prefix, "@") + colon := strings.Index(prefix, ":") + return at > 0 && colon > at +} + +func nonInteractiveGitEnv() []string { + return []string{"GIT_TERMINAL_PROMPT=0", "GIT_SSH_COMMAND=" + sshBatchModeCommand(os.Getenv("GIT_SSH_COMMAND"))} +} + +func sshBatchModeCommand(existing string) string { + existing = strings.TrimSpace(existing) + if existing == "" { + return "ssh -o BatchMode=yes" + } + tokens := splitShellFields(existing) + filtered := make([]string, 0, len(tokens)+2) + for i := 0; i < len(tokens); i++ { + tok := tokens[i] + lower := strings.ToLower(tok) + if lower == "-o" && i+1 < len(tokens) && isBatchModeOption(tokens[i+1]) { + i++ + continue + } + if strings.HasPrefix(lower, "-obatchmode=") { + continue + } + filtered = append(filtered, tok) + } + if len(filtered) == 0 { + filtered = append(filtered, "ssh") + } + filtered = append(filtered, "-o", "BatchMode=yes") + for i, tok := range filtered { + filtered[i] = shellQuote(tok) + } + return strings.Join(filtered, " ") +} + +func splitShellFields(s string) []string { + var fields []string + var b strings.Builder + var quote rune + escaped := false + for _, r := range s { + if escaped { + if r == '$' { + b.WriteString(`\$`) + } else { + b.WriteRune(r) + } + escaped = false + continue + } + switch { + case quote != 0: + if r == quote { + quote = 0 + } else if r == '$' && quote == '\'' { + b.WriteString(`\$`) + } else if r == '\\' 
&& quote == '"' { + escaped = true + } else { + b.WriteRune(r) + } + case r == '\'' || r == '"': + quote = r + case r == '\\': + escaped = true + case r == ' ' || r == '\t' || r == '\n': + if b.Len() > 0 { + fields = append(fields, b.String()) + b.Reset() + } + default: + b.WriteRune(r) + } + } + if escaped { + b.WriteRune('\\') + } + if b.Len() > 0 { + fields = append(fields, b.String()) + } + return fields +} + +func isBatchModeOption(opt string) bool { + parts := strings.Fields(strings.ToLower(strings.TrimSpace(opt))) + if len(parts) == 0 { + return false + } + return parts[0] == "batchmode" || strings.HasPrefix(parts[0], "batchmode=") +} + +func shellQuote(s string) string { + if s == "" { + return "''" + } + if strings.Contains(s, "$") { + return doubleQuote(s) + } + if isShellSafe(s) { + return s + } + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} + +func doubleQuote(s string) string { + var b strings.Builder + b.WriteByte('"') + for i := 0; i < len(s); i++ { + if s[i] == '\\' && i+1 < len(s) && s[i+1] == '$' { + b.WriteString(`\$`) + i++ + continue + } + switch s[i] { + case '\\', '"', '`': + b.WriteByte('\\') + } + b.WriteByte(s[i]) + } + b.WriteByte('"') + return b.String() +} + +func isShellSafe(s string) bool { + for _, r := range s { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { + continue + } + if strings.ContainsRune("@%_+=:,./-~$", r) { + continue + } + return false + } + return true +} + +func fetchRefTarget(repo model.RepoConfig, ref string) (fetchRefInfo, error) { + ref = strings.TrimSpace(ref) + if ref == "" { + ref = strings.TrimSpace(repo.Branch) + } + if ref == "" { + return fetchRefInfo{}, errors.New("fetch ref is required") + } + if branch := branchName(ref); branch != "" { + return fetchRefInfo{ + sourceRef: "refs/heads/" + branch, + remoteRef: "refs/remotes/origin/" + branch, + branch: branch, + }, nil + } + if strings.HasPrefix(ref, "refs/") { + return fetchRefInfo{sourceRef: ref, remoteRef: 
fetchedFullRefRemoteTrackingRef}, nil + } + return fetchRefInfo{}, errors.New("fetch ref is required") +} + +func branchName(ref string) string { + ref = strings.TrimSpace(ref) + if strings.HasPrefix(ref, "refs/heads/") { + return strings.TrimPrefix(ref, "refs/heads/") + } + if strings.HasPrefix(ref, "refs/remotes/origin/") { + return strings.TrimPrefix(ref, "refs/remotes/origin/") + } + if strings.HasPrefix(ref, "origin/") { + return strings.TrimPrefix(ref, "origin/") + } + if strings.HasPrefix(ref, "refs/") { + return "" } + return ref } func rootNode(repoID model.RepoID) model.BaseNode { diff --git a/internal/gitstore/gitstore_benchmark_test.go b/internal/gitstore/gitstore_benchmark_test.go new file mode 100644 index 0000000..933f1df --- /dev/null +++ b/internal/gitstore/gitstore_benchmark_test.go @@ -0,0 +1,73 @@ +package gitstore + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + "github.com/cloudflare/artifact-fs/internal/model" +) + +func BenchmarkBuildTreeIndexSynthetic(b *testing.B) { + const objects = 4096 + workDir, gitDir := createBuildTreeBenchmarkRepo(b, objects) + b.Cleanup(func() { _ = os.RemoveAll(workDir) }) + + cfg := model.RepoConfig{ID: "repo", Name: "repo", GitDir: gitDir} + store := New(nil) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + head, _, err := store.ResolveHEAD(ctx, cfg) + if err != nil { + b.Fatalf("ResolveHEAD: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + nodes, err := store.BuildTreeIndex(context.Background(), cfg, head) + if err != nil { + b.Fatalf("BuildTreeIndex: %v", err) + } + b.ReportMetric(float64(len(nodes)), "nodes/op") + } +} + +func createBuildTreeBenchmarkRepo(b *testing.B, objects int) (workDir string, gitDir string) { + b.Helper() + workDir, err := os.MkdirTemp("", "artifact-fs-buildtree-bench-") + if err != nil { + b.Fatal(err) + } + runBuildTreeBenchmarkGit(b, workDir, "init") + 
runBuildTreeBenchmarkGit(b, workDir, "config", "user.name", "BuildTree Bench") + runBuildTreeBenchmarkGit(b, workDir, "config", "user.email", "buildtree-bench@example.com") + for i := range objects { + dir := filepath.Join(workDir, fmt.Sprintf("dir-%02d", i%16)) + if err := os.MkdirAll(dir, 0o755); err != nil { + b.Fatal(err) + } + path := filepath.Join(dir, fmt.Sprintf("file-%04d.txt", i)) + data := []byte(fmt.Sprintf("blob payload %04d\n", i)) + if err := os.WriteFile(path, data, 0o644); err != nil { + b.Fatal(err) + } + } + runBuildTreeBenchmarkGit(b, workDir, "add", ".") + runBuildTreeBenchmarkGit(b, workDir, "commit", "-m", "add benchmark blobs") + return workDir, filepath.Join(workDir, ".git") +} + +func runBuildTreeBenchmarkGit(b *testing.B, dir string, args ...string) { + b.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + b.Fatalf("git %v: %v\n%s", args, err, out) + } +} diff --git a/internal/gitstore/gitstore_test.go b/internal/gitstore/gitstore_test.go index b6e7d22..ff55a0d 100644 --- a/internal/gitstore/gitstore_test.go +++ b/internal/gitstore/gitstore_test.go @@ -89,6 +89,46 @@ func TestBlobToCacheBinarySafe(t *testing.T) { } } +func TestBlobToCacheHonorsCanceledContext(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + repo := filepath.Join(tmp, "repo") + run(t, "git", "init", repo) + os.WriteFile(filepath.Join(repo, "file.txt"), []byte("line\n"), 0o644) + run(t, "git", "-C", repo, "add", "file.txt") + run(t, "git", "-C", repo, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "init") + + cfg := model.RepoConfig{ID: "x", GitDir: filepath.Join(repo, ".git"), BlobCacheDir: filepath.Join(tmp, "cache")} + store := New(nil) + oid, _, err := store.ResolveHEAD(context.Background(), cfg) + if err != nil { + t.Fatal(err) + } + nodes, err := store.BuildTreeIndex(context.Background(), cfg, oid) + if err != nil { + t.Fatal(err) + } + var blobOID string + for _, n := 
range nodes { + if n.Path == "file.txt" { + blobOID = n.ObjectOID + } + } + if blobOID == "" { + t.Fatal("no blob OID found") + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + dst := filepath.Join(tmp, "cache", blobOID) + _, err = store.BlobToCache(ctx, cfg, blobOID, dst) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want context.Canceled", err) + } + if _, err := os.Stat(dst); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("cache file should not be written after cancellation: %v", err) + } +} + func TestReadBlobRespectsMaxBytes(t *testing.T) { t.Parallel() tmp := t.TempDir() @@ -214,38 +254,439 @@ func TestReadTreeHEAD(t *testing.T) { } } -func TestCredentialEnvEscapesSingleQuotes(t *testing.T) { +func TestFetchRefNonInteractiveAndPrepareFetchedBranch(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + bare := filepath.Join(tmp, "origin.git") + work := filepath.Join(tmp, "work") + preparedGitDir := filepath.Join(tmp, "prepared.git") + preparedWorktree := filepath.Join(tmp, "prepared") + + run(t, "git", "init", "--bare", bare) + run(t, "git", "clone", bare, work) + run(t, "git", "-C", work, "checkout", "-b", "master") + os.WriteFile(filepath.Join(work, "README.md"), []byte("hello\n"), 0o644) + run(t, "git", "-C", work, "add", "README.md") + run(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "init") + run(t, "git", "-C", work, "push", "origin", "master") + + run(t, "git", "init", "--separate-git-dir", preparedGitDir, "--initial-branch", "master", preparedWorktree) + run(t, "git", "-C", preparedWorktree, "remote", "add", "origin", "file://"+bare) + + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: preparedGitDir, Branch: "master"} + store := New(nil) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := store.ValidatePreparedGitDir(ctx, cfg); err != nil { + t.Fatalf("ValidatePreparedGitDir: %v", err) + } + if err := 
store.FetchRefNonInteractive(ctx, cfg, "master"); err != nil { + t.Fatalf("FetchRefNonInteractive: %v", err) + } + if err := store.PrepareFetchedBranch(ctx, cfg, "master"); err != nil { + t.Fatalf("PrepareFetchedBranch: %v", err) + } + oid, ref, err := store.ResolveHEAD(ctx, cfg) + if err != nil { + t.Fatalf("ResolveHEAD: %v", err) + } + if ref != "master" { + t.Fatalf("ref = %q, want master", ref) + } + nodes, err := store.BuildTreeIndex(ctx, cfg, oid) + if err != nil { + t.Fatalf("BuildTreeIndex: %v", err) + } + found := false + for _, n := range nodes { + if n.Path == "README.md" { + found = true + } + } + if !found { + t.Fatal("README.md not found in prepared tree") + } +} + +func TestPrepareFetchedBranchRefusesPreparedGitDirRewind(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + bare := filepath.Join(tmp, "origin.git") + work := filepath.Join(tmp, "work") + preparedGitDir := filepath.Join(tmp, "prepared.git") + preparedWorktree := filepath.Join(tmp, "prepared") + + run(t, "git", "init", "--bare", bare) + run(t, "git", "clone", bare, work) + run(t, "git", "-C", work, "checkout", "-b", "master") + os.WriteFile(filepath.Join(work, "README.md"), []byte("origin\n"), 0o644) + run(t, "git", "-C", work, "add", "README.md") + run(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "origin") + run(t, "git", "-C", work, "push", "origin", "master") + + run(t, "git", "clone", bare, preparedWorktree) + run(t, "git", "-C", preparedWorktree, "checkout", "master") + localPath := filepath.Join(preparedWorktree, "LOCAL.md") + os.WriteFile(localPath, []byte("local\n"), 0o644) + run(t, "git", "-C", preparedWorktree, "add", "LOCAL.md") + run(t, "git", "-C", preparedWorktree, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "local") + if err := os.Rename(filepath.Join(preparedWorktree, ".git"), preparedGitDir); err != nil { + t.Fatal(err) + } + + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: 
preparedGitDir, Branch: "master", PreparedGitDir: true} + store := New(nil) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + localOID, err := runGit(ctx, preparedGitDir, "rev-parse", "refs/heads/master") + if err != nil { + t.Fatal(err) + } + localOID = strings.TrimSpace(localOID) + + if err := store.FetchRefNonInteractive(ctx, cfg, "master"); err != nil { + t.Fatalf("FetchRefNonInteractive: %v", err) + } + err = store.PrepareFetchedBranch(ctx, cfg, "master") + if err == nil { + t.Fatal("expected non-fast-forward prepared branch update to fail") + } + if strings.Contains(err.Error(), localOID) { + t.Fatalf("error leaked commit details: %v", err) + } + afterOID, err := runGit(ctx, preparedGitDir, "rev-parse", "refs/heads/master") + if err != nil { + t.Fatal(err) + } + afterOID = strings.TrimSpace(afterOID) + if afterOID != localOID { + t.Fatalf("prepared branch moved to %s, want %s", afterOID, localOID) + } +} + +func TestPrepareExistingCloneNonInteractiveUpdatesBranch(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + bare := filepath.Join(tmp, "origin.git") + work := filepath.Join(tmp, "work") + gitDir := filepath.Join(tmp, "repo.git") + + run(t, "git", "init", "--bare", bare) + run(t, "git", "clone", bare, work) + run(t, "git", "-C", work, "checkout", "-b", "master") + if err := os.WriteFile(filepath.Join(work, "README.md"), []byte("master\n"), 0o644); err != nil { + t.Fatal(err) + } + run(t, "git", "-C", work, "add", "README.md") + run(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "master") + run(t, "git", "-C", work, "push", "origin", "master") + run(t, "git", "-C", work, "checkout", "-b", "dev") + os.WriteFile(filepath.Join(work, "DEV.md"), []byte("dev\n"), 0o644) + run(t, "git", "-C", work, "add", "DEV.md") + run(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "dev") + run(t, "git", "-C", work, "push", 
"origin", "dev") + + store := New(nil) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: gitDir, RemoteURL: "file://" + bare, Branch: "master", FetchRef: "master"} + if err := store.CloneBloblessNonInteractive(ctx, cfg); err != nil { + t.Fatalf("CloneBloblessNonInteractive: %v", err) + } + cfg.Branch = "dev" + cfg.FetchRef = "dev" + if err := store.PrepareExistingCloneNonInteractive(ctx, cfg); err != nil { + t.Fatalf("PrepareExistingCloneNonInteractive: %v", err) + } + oid, ref, err := store.ResolveHEAD(ctx, cfg) + if err != nil { + t.Fatalf("ResolveHEAD: %v", err) + } + if ref != "dev" { + t.Fatalf("ref = %q, want dev", ref) + } + nodes, err := store.BuildTreeIndex(ctx, cfg, oid) + if err != nil { + t.Fatalf("BuildTreeIndex: %v", err) + } + found := false + for _, n := range nodes { + if n.Path == "DEV.md" { + found = true + } + } + if !found { + t.Fatal("DEV.md not found after existing clone prepare") + } +} + +func TestCloneAndFetchRefSkipTags(t *testing.T) { + tmp := t.TempDir() + bare := filepath.Join(tmp, "origin.git") + work := filepath.Join(tmp, "work") + gitDir := filepath.Join(tmp, "repo.git") + + run(t, "git", "init", "--bare", bare) + run(t, "git", "clone", bare, work) + run(t, "git", "-C", work, "checkout", "-b", "master") + if err := os.WriteFile(filepath.Join(work, "README.md"), []byte("master\n"), 0o644); err != nil { + t.Fatal(err) + } + run(t, "git", "-C", work, "add", "README.md") + run(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "master") + run(t, "git", "-C", work, "push", "origin", "master") + + realGit, err := exec.LookPath("git") + if err != nil { + t.Fatal(err) + } + bin := filepath.Join(tmp, "bin") + if err := os.Mkdir(bin, 0o755); err != nil { + t.Fatal(err) + } + logPath := filepath.Join(tmp, "git.log") + fakeGit := filepath.Join(bin, "git") + if err := os.WriteFile(fakeGit, 
[]byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$GIT_COMMAND_LOG\"\nexec \"$REAL_GIT\" \"$@\"\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("GIT_COMMAND_LOG", logPath) + t.Setenv("REAL_GIT", realGit) + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + store := New(nil) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: gitDir, RemoteURL: "file://" + bare, Branch: "master", FetchRef: "master"} + if err := store.CloneBloblessNonInteractive(ctx, cfg); err != nil { + t.Fatalf("CloneBloblessNonInteractive: %v", err) + } + if err := store.FetchRefNonInteractive(ctx, cfg, cfg.FetchRef); err != nil { + t.Fatalf("FetchRefNonInteractive: %v", err) + } + if err := store.Fetch(ctx, cfg); err != nil { + t.Fatalf("Fetch: %v", err) + } + + logData, err := os.ReadFile(logPath) + if err != nil { + t.Fatal(err) + } + logText := string(logData) + if !strings.Contains(logText, "clone --filter=blob:none --no-checkout --single-branch --no-tags --branch master") { + t.Fatalf("clone did not include --no-tags; git log:\n%s", logText) + } + if !strings.Contains(logText, "fetch --filter=blob:none --no-tags origin +refs/heads/master:refs/remotes/origin/master") { + t.Fatalf("fetch did not include --no-tags; git log:\n%s", logText) + } + if !strings.Contains(logText, "fetch --no-tags origin") { + t.Fatalf("refresh fetch did not include --no-tags; git log:\n%s", logText) + } +} + +func TestPrepareExistingCloneRejectsCredentialedRemoteBeforeSetURL(t *testing.T) { + tmp := t.TempDir() + gitDir := filepath.Join(tmp, "repo.git") + worktree := filepath.Join(tmp, "worktree") + run(t, "git", "init", "--separate-git-dir", gitDir, "--initial-branch", "master", worktree) + run(t, "git", "-C", worktree, "remote", "add", "origin", "https://github.com/org/repo.git") + + bin := filepath.Join(tmp, "bin") + if err := os.Mkdir(bin, 0o755); err != nil { + t.Fatal(err) + } + marker := 
filepath.Join(tmp, "set-url-invoked") + fakeGit := filepath.Join(bin, "git") + if err := os.WriteFile(fakeGit, []byte("#!/bin/sh\ncase \"$*\" in *\"remote set-url\"*) : > \"$GIT_SET_URL_MARKER\";; esac\nexec /usr/bin/git \"$@\"\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("GIT_SET_URL_MARKER", marker) + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + store := New(nil) + for _, remote := range []string{ + "ssh:/git:secret@github.com/org/repo.git", + "ssh//git:secret@github.com/org/repo.git", + "ssh:/git:pa/ss@github.com/org/repo.git", + "alice:ghp_secret@github.com:org/repo.git", + } { + t.Run(remote, func(t *testing.T) { + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: gitDir, RemoteURL: remote, Branch: "master", FetchRef: "master"} + err := store.PrepareExistingCloneNonInteractive(context.Background(), cfg) + if err == nil { + t.Fatal("expected credentialed remote rejection") + } + if strings.Contains(err.Error(), "secret") { + t.Fatalf("error leaked credential: %v", err) + } + if _, statErr := os.Stat(marker); !errors.Is(statErr, os.ErrNotExist) { + t.Fatal("git remote set-url was invoked before rejecting credentialed remote") + } + }) + } +} + +func TestValidatePreparedGitDirRejectsCredentialedOrigin(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + preparedGitDir := filepath.Join(tmp, "prepared.git") + preparedWorktree := filepath.Join(tmp, "prepared") + run(t, "git", "init", "--separate-git-dir", preparedGitDir, "--initial-branch", "master", preparedWorktree) + run(t, "git", "-C", preparedWorktree, "remote", "add", "origin", "https://ghp_secret@github.com/org/repo.git") + + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: preparedGitDir, Branch: "master"} + store := New(nil) + err := store.ValidatePreparedGitDir(context.Background(), cfg) + if err == nil { + t.Fatal("expected credentialed origin rejection") + } + if strings.Contains(err.Error(), "ghp_secret") { + t.Fatalf("error leaked origin credential: %v", 
err) + } +} + +func TestValidatePreparedGitDirRejectsMalformedCredentialedOrigin(t *testing.T) { + for _, remote := range []string{ + "ssh:/git:secret@github.com/org/repo.git", + "alice:ghp_secret@github.com:org/repo.git", + } { + t.Run(remote, func(t *testing.T) { + tmp := t.TempDir() + preparedGitDir := filepath.Join(tmp, "prepared.git") + preparedWorktree := filepath.Join(tmp, "prepared") + run(t, "git", "init", "--separate-git-dir", preparedGitDir, "--initial-branch", "master", preparedWorktree) + run(t, "git", "-C", preparedWorktree, "remote", "add", "origin", remote) + + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: preparedGitDir, Branch: "master"} + store := New(nil) + if err := store.ValidatePreparedGitDir(context.Background(), cfg); err == nil { + t.Fatal("expected malformed credentialed origin rejection") + } + }) + } +} + +func TestValidatePreparedGitDirAllowsAtInHTTPSOriginPath(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + preparedGitDir := filepath.Join(tmp, "prepared.git") + preparedWorktree := filepath.Join(tmp, "prepared") + run(t, "git", "init", "--separate-git-dir", preparedGitDir, "--initial-branch", "master", preparedWorktree) + run(t, "git", "-C", preparedWorktree, "remote", "add", "origin", "https://git.example.com/team/repo@2026.git") + + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: preparedGitDir, Branch: "master"} + store := New(nil) + if err := store.ValidatePreparedGitDir(context.Background(), cfg); err != nil { + t.Fatalf("ValidatePreparedGitDir: %v", err) + } +} + +func TestFetchRefNonInteractiveFullRefPreparesDetachedHEAD(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + bare := filepath.Join(tmp, "origin.git") + work := filepath.Join(tmp, "work") + preparedGitDir := filepath.Join(tmp, "prepared.git") + preparedWorktree := filepath.Join(tmp, "prepared") + + run(t, "git", "init", "--bare", bare) + run(t, "git", "clone", bare, work) + run(t, "git", "-C", work, "checkout", "-b", "master") + 
os.WriteFile(filepath.Join(work, "README.md"), []byte("hello\n"), 0o644) + run(t, "git", "-C", work, "add", "README.md") + run(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "init") + run(t, "git", "-C", work, "push", "origin", "master") + run(t, "git", "-C", work, "checkout", "-b", "pull-request") + os.WriteFile(filepath.Join(work, "PR.md"), []byte("pull request\n"), 0o644) + run(t, "git", "-C", work, "add", "PR.md") + run(t, "git", "-C", work, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "pr") + run(t, "git", "-C", work, "push", "origin", "HEAD:refs/pull/10/head") + + run(t, "git", "init", "--separate-git-dir", preparedGitDir, "--initial-branch", "master", preparedWorktree) + run(t, "git", "-C", preparedWorktree, "remote", "add", "origin", "file://"+bare) + + cfg := model.RepoConfig{ID: "x", Name: "x", GitDir: preparedGitDir, Branch: "master"} + store := New(nil) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := store.FetchRefNonInteractive(ctx, cfg, "refs/pull/10/head"); err != nil { + t.Fatalf("FetchRefNonInteractive: %v", err) + } + if _, err := runGit(ctx, preparedGitDir, "rev-parse", "--verify", fetchedFullRefRemoteTrackingRef+"^{commit}"); err != nil { + t.Fatalf("expected fetched full ref at safe remote-tracking ref: %v", err) + } + if err := store.PrepareFetchedBranch(ctx, cfg, "refs/pull/10/head"); err != nil { + t.Fatalf("PrepareFetchedBranch: %v", err) + } + oid, ref, err := store.ResolveHEAD(ctx, cfg) + if err != nil { + t.Fatalf("ResolveHEAD: %v", err) + } + if ref != "DETACHED" { + t.Fatalf("ref = %q, want DETACHED", ref) + } + nodes, err := store.BuildTreeIndex(ctx, cfg, oid) + if err != nil { + t.Fatalf("BuildTreeIndex: %v", err) + } + found := false + for _, n := range nodes { + if n.Path == "PR.md" { + found = true + } + } + if !found { + t.Fatal("PR.md not found in prepared tree") + } +} + +func 
TestCredentialEnvKeepsSecretsOutOfHelperCommand(t *testing.T) { t.Parallel() - // Password with a single quote should be escaped - safeURL, env := credentialEnv("https://user:p@ss'word@github.com/org/repo.git") + safeURL, env, err := credentialEnv("https://user:p@ss'word@github.com/org/repo.git") + if err != nil { + t.Fatalf("credentialEnv: %v", err) + } if safeURL == "" { t.Fatal("expected non-empty safe URL") } if strings.Contains(safeURL, "p@ss") { t.Fatalf("safe URL should not contain password: %s", safeURL) } - // The credential helper env var should contain escaped quote - found := false + foundHelper := false + foundReset := false + foundPassword := false for _, e := range env { - if val, ok := strings.CutPrefix(e, "GIT_CONFIG_VALUE_0="); ok { - found = true + if e == "GIT_CONFIG_VALUE_0=" { + foundReset = true + } + if val, ok := strings.CutPrefix(e, "GIT_CONFIG_VALUE_1="); ok { + foundHelper = true if strings.Contains(val, "p@ss'word") { - t.Fatalf("unescaped password in helper: %s", val) - } - // Should contain the escaped form - if !strings.Contains(val, `'\''`) { - t.Fatalf("expected escaped single quote in helper, got: %s", val) + t.Fatalf("password leaked in helper command: %s", val) } } + if e == "ARTIFACT_FS_GIT_PASSWORD=p@ss'word" { + foundPassword = true + } } - if !found { - t.Fatal("expected GIT_CONFIG_VALUE_0 in env") + if !foundReset { + t.Fatal("expected empty credential.helper reset") + } + if !foundHelper { + t.Fatal("expected GIT_CONFIG_VALUE_1 in env") + } + if !foundPassword { + t.Fatalf("expected password env var, got %v", env) } } func TestCredentialEnvNoCredentials(t *testing.T) { t.Parallel() - safeURL, env := credentialEnv("https://github.com/org/repo.git") + safeURL, env, err := credentialEnv("https://github.com/org/repo.git") + if err != nil { + t.Fatalf("credentialEnv: %v", err) + } if safeURL != "https://github.com/org/repo.git" { t.Fatalf("expected unchanged URL, got %s", safeURL) } @@ -254,15 +695,438 @@ func 
TestCredentialEnvNoCredentials(t *testing.T) { } } +func TestCredentialEnvAllowsFileURLPathWithAtSign(t *testing.T) { + t.Parallel() + const remote = "file:///tmp/repo@2026.git" + safeURL, env, err := credentialEnv(remote) + if err != nil { + t.Fatalf("credentialEnv: %v", err) + } + if safeURL != remote { + t.Fatalf("safe URL = %q, want %q", safeURL, remote) + } + if len(env) != 0 { + t.Fatalf("expected no env vars, got %v", env) + } +} + +func TestCredentialEnvAllowsSCPStyleRootPathWithAtSign(t *testing.T) { + t.Parallel() + const remote = "git@example.com:repo:v1@2026.git" + safeURL, env, err := credentialEnv(remote) + if err != nil { + t.Fatalf("credentialEnv: %v", err) + } + if safeURL != remote { + t.Fatalf("safe URL = %q, want %q", safeURL, remote) + } + if len(env) != 0 { + t.Fatalf("expected no env vars, got %v", env) + } +} + +func TestCredentialEnvRejectsQueryAndFragment(t *testing.T) { + t.Parallel() + for _, raw := range []string{ + "https://github.com/org/repo.git?access_token=secret", + "https://github.com/org/repo.git#access_token=secret", + "https://github.com/org/repo.git#", + "git@github.com:org/repo.git?access_token=secret", + "git@github.com:org/repo.git#access_token=secret", + } { + t.Run(raw, func(t *testing.T) { + if _, _, err := credentialEnv(raw); err == nil { + t.Fatal("expected query or fragment rejection") + } + }) + } +} + func TestCredentialEnvTokenAsUsername(t *testing.T) { t.Parallel() - safeURL, env := credentialEnv("https://ghp_abc123@github.com/org/repo.git") + safeURL, env, err := credentialEnv("https://ghp_abc123@github.com/org/repo.git") + if err != nil { + t.Fatalf("credentialEnv: %v", err) + } if strings.Contains(safeURL, "ghp_abc123") { t.Fatalf("token should be stripped from safe URL: %s", safeURL) } if len(env) == 0 { t.Fatal("expected credential helper env vars") } + for _, e := range env { + if strings.HasPrefix(e, "GIT_CONFIG_VALUE_1=") && strings.Contains(e, "ghp_abc123") { + t.Fatalf("credential helper command leaked 
token: %s", e) + } + } +} + +func TestCredentialEnvPreservesSSHUsername(t *testing.T) { + t.Parallel() + safeURL, env, err := credentialEnv("ssh://git@github.com/org/repo.git") + if err != nil { + t.Fatalf("credentialEnv: %v", err) + } + if safeURL != "ssh://git@github.com/org/repo.git" { + t.Fatalf("safe URL = %q, want SSH username preserved", safeURL) + } + if len(env) != 0 { + t.Fatalf("expected no credential helper env for SSH username, got %v", env) + } +} + +func TestCredentialEnvRejectsGitProtocolUsername(t *testing.T) { + t.Parallel() + if _, _, err := credentialEnv("git://ghp_secret@github.com/org/repo.git"); err == nil { + t.Fatal("expected git protocol username rejection") + } +} + +func TestCredentialEnvRejectsSSHTokenUsername(t *testing.T) { + t.Parallel() + if _, _, err := credentialEnv("ssh://ghp_abcdefghijklmnopqrstuvwxyz@github.com/org/repo.git"); err == nil { + t.Fatal("expected SSH token username rejection") + } +} + +func TestCredentialEnvRejectsSSHPassword(t *testing.T) { + t.Parallel() + if _, _, err := credentialEnv("ssh://git:secret@github.com/org/repo.git"); err == nil { + t.Fatal("expected SSH password rejection") + } +} + +func TestCredentialEnvRejectsMalformedSSHPassword(t *testing.T) { + t.Parallel() + for _, raw := range []string{ + "ssh:/git:secret@github.com/org/repo.git", + "ssh:/git:bad%zz@github.com/org/repo.git", + "alice:ghp_secret@github.com:org/repo.git", + "x-token-auth:secret@bitbucket.org/org/repo.git", + "https://ghp_secret/part@example.com/org/repo.git", + "https://ghp_secret/part@example.com", + } { + t.Run(raw, func(t *testing.T) { + if _, _, err := credentialEnv(raw); err == nil { + t.Fatal("expected malformed SSH password rejection") + } + }) + } +} + +func TestCloneBloblessRejectsMalformedCredentialURLBeforeGit(t *testing.T) { + tmp := t.TempDir() + bin := filepath.Join(tmp, "bin") + if err := os.Mkdir(bin, 0o755); err != nil { + t.Fatal(err) + } + marker := filepath.Join(tmp, "git-invoked") + fakeGit := 
filepath.Join(bin, "git") + if err := os.WriteFile(fakeGit, []byte("#!/bin/sh\n: > \"$GIT_INVOKED_MARKER\"\nexit 1\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("GIT_INVOKED_MARKER", marker) + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + cfg := model.RepoConfig{ + GitDir: filepath.Join(tmp, "repo.git"), + RemoteURL: "https://user:bad%zz@example.com/org/repo.git", + Branch: "main", + } + store := New(nil) + err := store.CloneBlobless(context.Background(), cfg) + if err == nil { + t.Fatal("expected malformed remote URL error") + } + if strings.Contains(err.Error(), "bad%zz") || strings.Contains(err.Error(), "user") { + t.Fatalf("error leaked credential URL: %v", err) + } + if _, statErr := os.Stat(marker); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("git was invoked before rejecting malformed URL") + } +} + +func TestCloneBloblessRejectsMalformedHTTPSUserinfoBeforeGit(t *testing.T) { + tmp := t.TempDir() + bin := filepath.Join(tmp, "bin") + if err := os.Mkdir(bin, 0o755); err != nil { + t.Fatal(err) + } + marker := filepath.Join(tmp, "git-invoked") + fakeGit := filepath.Join(bin, "git") + if err := os.WriteFile(fakeGit, []byte("#!/bin/sh\n: > \"$GIT_INVOKED_MARKER\"\nexit 1\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("GIT_INVOKED_MARKER", marker) + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + cfg := model.RepoConfig{ + GitDir: filepath.Join(tmp, "repo.git"), + RemoteURL: "https:/user:ghp_secret@github.com/org/repo.git", + Branch: "main", + } + store := New(nil) + err := store.CloneBlobless(context.Background(), cfg) + if err == nil { + t.Fatal("expected malformed remote URL error") + } + if strings.Contains(err.Error(), "ghp_secret") || strings.Contains(err.Error(), "user") { + t.Fatalf("error leaked credential URL: %v", err) + } + if _, statErr := os.Stat(marker); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("git was invoked before rejecting malformed URL") + } +} + +func 
TestCloneBloblessRejectsMalformedHTTPParseErrorBeforeGit(t *testing.T) { + tmp := t.TempDir() + bin := filepath.Join(tmp, "bin") + if err := os.Mkdir(bin, 0o755); err != nil { + t.Fatal(err) + } + marker := filepath.Join(tmp, "git-invoked") + fakeGit := filepath.Join(bin, "git") + if err := os.WriteFile(fakeGit, []byte("#!/bin/sh\n: > \"$GIT_INVOKED_MARKER\"\nexit 1\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("GIT_INVOKED_MARKER", marker) + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + cfg := model.RepoConfig{ + GitDir: filepath.Join(tmp, "repo.git"), + RemoteURL: "https//ghp_secret%zz@github.com/org/repo.git", + Branch: "main", + } + store := New(nil) + err := store.CloneBlobless(context.Background(), cfg) + if err == nil { + t.Fatal("expected malformed remote URL error") + } + if strings.Contains(err.Error(), "ghp_secret") || strings.Contains(err.Error(), "%zz") { + t.Fatalf("error leaked credential URL: %v", err) + } + if _, statErr := os.Stat(marker); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("git was invoked before rejecting malformed URL") + } +} + +func TestCloneBloblessRejectsMalformedGitStyleCredentialBeforeGit(t *testing.T) { + tmp := t.TempDir() + bin := filepath.Join(tmp, "bin") + if err := os.Mkdir(bin, 0o755); err != nil { + t.Fatal(err) + } + marker := filepath.Join(tmp, "git-invoked") + fakeGit := filepath.Join(bin, "git") + if err := os.WriteFile(fakeGit, []byte("#!/bin/sh\n: > \"$GIT_INVOKED_MARKER\"\nexit 1\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("GIT_INVOKED_MARKER", marker) + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + cfg := model.RepoConfig{ + GitDir: filepath.Join(tmp, "repo.git"), + RemoteURL: "git:secret@github.com:org/repo.git", + Branch: "main", + } + store := New(nil) + err := store.CloneBlobless(context.Background(), cfg) + if err == nil { + t.Fatal("expected malformed remote URL error") + } + if strings.Contains(err.Error(), "secret") { + 
t.Fatalf("error leaked credential URL: %v", err) + } + if _, statErr := os.Stat(marker); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("git was invoked before rejecting malformed URL") + } +} + +func TestCloneBloblessRejectsPathSplitHTTPCredentialsBeforeGit(t *testing.T) { + tmp := t.TempDir() + bin := filepath.Join(tmp, "bin") + if err := os.Mkdir(bin, 0o755); err != nil { + t.Fatal(err) + } + marker := filepath.Join(tmp, "git-invoked") + fakeGit := filepath.Join(bin, "git") + if err := os.WriteFile(fakeGit, []byte("#!/bin/sh\n: > \"$GIT_INVOKED_MARKER\"\nexit 1\n"), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("GIT_INVOKED_MARKER", marker) + t.Setenv("PATH", bin+string(os.PathListSeparator)+os.Getenv("PATH")) + + cfg := model.RepoConfig{ + GitDir: filepath.Join(tmp, "repo.git"), + RemoteURL: "https://user:123/ss@example.com/org/repo.git", + Branch: "main", + } + store := New(nil) + err := store.CloneBlobless(context.Background(), cfg) + if err == nil { + t.Fatal("expected malformed remote URL error") + } + if strings.Contains(err.Error(), "123") || strings.Contains(err.Error(), "ss") { + t.Fatalf("error leaked credential URL: %v", err) + } + if _, statErr := os.Stat(marker); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("git was invoked before rejecting malformed URL") + } +} + +func TestCredentialEnvRejectsHTTPSLikeUserinfoTypos(t *testing.T) { + t.Parallel() + for _, raw := range []string{ + "https:/user:ghp_secret@github.com/org/repo.git", + "https//ghp_secret@github.com/org/repo.git", + } { + t.Run(raw, func(t *testing.T) { + if _, _, err := credentialEnv(raw); err == nil { + t.Fatal("expected malformed HTTP-like remote rejection") + } + }) + } +} + +func TestCredentialEnvAllowsAtInHTTPSPath(t *testing.T) { + t.Parallel() + safeURL, env, err := credentialEnv("https://git.example.com/team/repo:v1@2026.git") + if err != nil { + t.Fatalf("credentialEnv: %v", err) + } + if safeURL != "https://git.example.com/team/repo:v1@2026.git" { + t.Fatalf("safe 
URL = %q", safeURL) + } + if len(env) != 0 { + t.Fatalf("expected no credential helper env, got %v", env) + } +} + +func TestNonInteractiveGitEnvForcesSSHBatchMode(t *testing.T) { + t.Setenv("GIT_SSH_COMMAND", "ssh -o BatchMode=no -i /secrets/deploy_key -o IdentitiesOnly=yes") + env := nonInteractiveGitEnv() + for _, e := range env { + if strings.HasPrefix(e, "GIT_SSH_COMMAND=") { + if !strings.Contains(e, "-i /secrets/deploy_key") { + t.Fatalf("expected existing identity option to be preserved, got %q", e) + } + if strings.Contains(e, "BatchMode=no") { + t.Fatalf("expected existing BatchMode option to be replaced, got %q", e) + } + if strings.Contains(e, "BatchMode=yes") { + return + } + break + } + } + t.Fatalf("expected forced GIT_SSH_COMMAND, got %v", env) +} + +func TestNonInteractiveGitEnvDefaultSSHBatchMode(t *testing.T) { + t.Setenv("GIT_SSH_COMMAND", "") + env := nonInteractiveGitEnv() + for _, e := range env { + if e == "GIT_SSH_COMMAND=ssh -o BatchMode=yes" { + return + } + } + t.Fatalf("expected forced GIT_SSH_COMMAND, got %v", env) +} + +func TestNonInteractiveGitEnvStripsQuotedBatchMode(t *testing.T) { + for _, command := range []string{ + `ssh -o "BatchMode=no" -i /secrets/deploy_key`, + `ssh -o BatchMode="no" -i /secrets/deploy_key`, + `ssh -o "BatchMode"=no -i /secrets/deploy_key`, + `ssh -o 'BatchMode no' -i /secrets/deploy_key`, + `ssh '-o' 'BatchMode=no' -i /secrets/deploy_key`, + `ssh -oBatchMode="no" -i /secrets/deploy_key`, + } { + t.Run(command, func(t *testing.T) { + t.Setenv("GIT_SSH_COMMAND", command) + env := nonInteractiveGitEnv() + for _, e := range env { + if strings.HasPrefix(e, "GIT_SSH_COMMAND=") { + if strings.Contains(e, "BatchMode=no") || strings.Contains(e, `BatchMode="no"`) { + t.Fatalf("expected quoted BatchMode option to be replaced, got %q", e) + } + if !strings.Contains(e, "-i /secrets/deploy_key") || !strings.Contains(e, "BatchMode=yes") { + t.Fatalf("expected identity and BatchMode=yes, got %q", e) + } + return + } + } + 
t.Fatalf("expected forced GIT_SSH_COMMAND, got %v", env) + }) + } +} + +func TestNonInteractiveGitEnvPreservesProxyCommand(t *testing.T) { + t.Setenv("GIT_SSH_COMMAND", "ssh -o ProxyCommand='ssh -o BatchMode=no bastion' -i /secrets/deploy_key") + env := nonInteractiveGitEnv() + for _, e := range env { + if strings.HasPrefix(e, "GIT_SSH_COMMAND=") { + if !strings.Contains(e, "ProxyCommand=ssh -o BatchMode=no bastion") { + t.Fatalf("expected ProxyCommand to be preserved, got %q", e) + } + if !strings.Contains(e, "BatchMode=yes") { + t.Fatalf("expected top-level BatchMode=yes, got %q", e) + } + return + } + } + t.Fatalf("expected forced GIT_SSH_COMMAND, got %v", env) +} + +func TestNonInteractiveGitEnvQuotesShellMetacharacters(t *testing.T) { + t.Setenv("GIT_SSH_COMMAND", `ssh -i /tmp/key\ prod -o UserKnownHostsFile=/tmp/known\ hosts`) + env := nonInteractiveGitEnv() + for _, e := range env { + if strings.HasPrefix(e, "GIT_SSH_COMMAND=") { + if !strings.Contains(e, "'/tmp/key prod'") || !strings.Contains(e, "'UserKnownHostsFile=/tmp/known hosts'") { + t.Fatalf("expected escaped shell paths to be quoted, got %q", e) + } + if !strings.Contains(e, "BatchMode=yes") { + t.Fatalf("expected BatchMode=yes, got %q", e) + } + return + } + } + t.Fatalf("expected forced GIT_SSH_COMMAND, got %v", env) +} + +func TestNonInteractiveGitEnvPreservesShellExpansion(t *testing.T) { + t.Setenv("GIT_SSH_COMMAND", `ssh -i "$HOME/.ssh/deploy key"`) + env := nonInteractiveGitEnv() + for _, e := range env { + if strings.HasPrefix(e, "GIT_SSH_COMMAND=") { + if !strings.Contains(e, "$HOME/.ssh/deploy key") { + t.Fatalf("expected HOME expansion to be preserved, got %q", e) + } + if strings.Contains(e, "'$HOME/.ssh/deploy key'") { + t.Fatalf("expected HOME expansion not to be single-quoted, got %q", e) + } + return + } + } + t.Fatalf("expected forced GIT_SSH_COMMAND, got %v", env) +} + +func TestNonInteractiveGitEnvPreservesEscapedDollar(t *testing.T) { + t.Setenv("GIT_SSH_COMMAND", `ssh -i 
// killBatchCommand forcibly terminates cmd. It is a no-op when cmd is nil or
// has no Process. Otherwise it sends SIGKILL to the process group whose ID
// equals the command's PID (configureBatchCommand sets Setpgid), then kills
// the process itself. Errors from both calls are deliberately ignored.
func killBatchCommand(cmd *exec.Cmd) {
	if cmd == nil {
		return
	}
	proc := cmd.Process
	if proc == nil {
		return
	}
	if proc.Pid > 0 {
		// A negative PID addresses the whole process group.
		_ = syscall.Kill(-proc.Pid, syscall.SIGKILL)
	}
	_ = proc.Kill()
}
work is enqueued onHydrated OnHydratedFunc verified map[string]struct{} + active map[string]struct{} } type result struct { @@ -52,11 +54,13 @@ type verifyResult struct { func New(fetcher BlobFetcher) *Service { return &Service{ fetcher: fetcher, + queued: map[string]*taskItem{}, wait: newInflight[result](), verifying: newInflight[verifyResult](), stopCh: make(chan struct{}), workReady: make(chan struct{}, 1), verified: map[string]struct{}{}, + active: map[string]struct{}{}, } } @@ -100,30 +104,90 @@ func (s *Service) Stop() { func (s *Service) Enqueue(task model.HydrationTask) { s.mu.Lock() - heap.Push(&s.pq, &taskItem{task: task}) + enqueued := s.enqueueLocked(task) s.mu.Unlock() - s.signalWork() + if enqueued { + s.signalWork() + } +} + +func (s *Service) EnqueueBatch(tasks []model.HydrationTask) { + if len(tasks) == 0 { + return + } + s.mu.Lock() + enqueued := 0 + for _, task := range tasks { + if s.enqueueLocked(task) { + enqueued++ + } + } + s.mu.Unlock() + for range enqueued { + s.signalWork() + } +} + +func (s *Service) enqueueLocked(task model.HydrationTask) bool { + key := taskKey(task.RepoID, task.ObjectOID) + if _, ok := s.active[key]; ok { + return false + } + if item, ok := s.queued[key]; ok { + if task.Priority > item.task.Priority { + item.task = task + heap.Fix(&s.pq, item.index) + return true + } + return false + } + item := &taskItem{task: task} + heap.Push(&s.pq, item) + s.queued[key] = item + return true } func (s *Service) EnsureHydrated(ctx context.Context, repo model.RepoConfig, node model.BaseNode) (cachePath string, size int64, err error) { cachePath = cachePathFor(repo, node.ObjectOID) + key := taskKey(repo.ID, node.ObjectOID) + if ch, ok := s.joinInflight(key); ok { + return s.awaitHydration(ctx, key, ch) + } if size, ok, err := s.validateCachedBlob(ctx, repo, cachePath, node); err != nil { return "", 0, err } else if ok { return cachePath, size, nil } - key := taskKey(repo.ID, node.ObjectOID) ch := make(chan result, 1) s.mu.Lock() + if 
_, ok := s.active[key]; ok || len(s.wait[key]) > 0 { + s.wait.add(key, ch) + s.mu.Unlock() + return s.awaitHydration(ctx, key, ch) + } first := s.wait.add(key, ch) if first { - heap.Push(&s.pq, &taskItem{task: explicitReadTask(repo.ID, node.Path, node.ObjectOID)}) + first = s.enqueueLocked(explicitReadTask(repo.ID, node)) } s.mu.Unlock() if first { s.signalWork() } + return s.awaitHydration(ctx, key, ch) +} + +func (s *Service) joinInflight(key string) (chan result, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.active[key]; !ok && len(s.wait[key]) == 0 { + return nil, false + } + ch := make(chan result, 1) + s.wait.add(key, ch) + return ch, true +} +func (s *Service) awaitHydration(ctx context.Context, key string, ch chan result) (cachePath string, size int64, err error) { select { case <-ctx.Done(): // Remove our channel from the wait list so the worker doesn't @@ -222,12 +286,6 @@ func (s *Service) isVerified(key string) bool { return ok } -func (s *Service) markVerified(key string) { - s.mu.Lock() - defer s.mu.Unlock() - s.verified[key] = struct{}{} -} - func (s *Service) clearVerified(key string) { s.mu.Lock() defer s.mu.Unlock() @@ -335,12 +393,30 @@ func (s *Service) step(repo model.RepoConfig) bool { } item := heap.Pop(&s.pq).(*taskItem) key := taskKey(item.task.RepoID, item.task.ObjectOID) - waits := s.wait.take(key) + delete(s.queued, key) + if _, ok := s.active[key]; ok { + s.mu.Unlock() + return true + } + s.active[key] = struct{}{} + more := len(s.pq) > 0 s.mu.Unlock() + if more { + s.signalWork() + } cachePath := cachePathFor(repo, item.task.ObjectOID) + node := taskBaseNode(item.task) + if size, ok, err := s.validateCachedBlob(context.Background(), repo, cachePath, node); err != nil { + s.finishHydration(key, result{err: err}) + return true + } else if ok { + s.finishHydration(key, result{cachePath: cachePath, size: size}) + s.notifyHydrated(item.task.RepoID, item.task.ObjectOID, size) + return true + } if err := 
os.MkdirAll(filepath.Dir(cachePath), 0o755); err != nil { - notifyWaiters(waits, result{err: err}) + s.finishHydration(key, result{err: err}) return true } // Use a timeout context derived from stopCh so stuck blob fetches don't @@ -356,18 +432,44 @@ func (s *Service) step(repo model.RepoConfig) bool { }() size, err := s.fetcher.BlobToCache(fetchCtx, repo, item.task.ObjectOID, cachePath) if err != nil { - notifyWaiters(waits, result{err: fmt.Errorf("hydrate %s: %w", item.task.Path, err)}) + s.finishHydration(key, result{err: fmt.Errorf("hydrate %s: %w", item.task.Path, err)}) return true } - s.markVerified(taskKey(item.task.RepoID, item.task.ObjectOID)) + s.mu.Lock() + s.verified[key] = struct{}{} + waits := s.wait.take(key) + delete(s.active, key) + s.mu.Unlock() notifyWaiters(waits, result{cachePath: cachePath, size: size, err: nil}) + s.notifyHydrated(item.task.RepoID, item.task.ObjectOID, size) + return true +} + +func taskBaseNode(task model.HydrationTask) model.BaseNode { + return model.BaseNode{ + RepoID: task.RepoID, + Path: task.Path, + ObjectOID: task.ObjectOID, + SizeState: task.SizeState, + SizeBytes: task.SizeBytes, + } +} + +func (s *Service) notifyHydrated(repoID model.RepoID, objectOID string, size int64) { s.mu.Lock() fn := s.onHydrated s.mu.Unlock() if fn != nil { - fn(item.task.RepoID, item.task.ObjectOID, size) + fn(repoID, objectOID, size) } - return true +} + +func (s *Service) finishHydration(key string, r result) { + s.mu.Lock() + waits := s.wait.take(key) + delete(s.active, key) + s.mu.Unlock() + notifyWaiters(waits, r) } func taskKey(repoID model.RepoID, oid string) string { @@ -378,11 +480,13 @@ func cachePathFor(repo model.RepoConfig, oid string) string { return filepath.Join(repo.BlobCacheDir, oid) } -func explicitReadTask(repoID model.RepoID, path string, oid string) model.HydrationTask { +func explicitReadTask(repoID model.RepoID, node model.BaseNode) model.HydrationTask { return model.HydrationTask{ RepoID: repoID, - Path: path, - 
ObjectOID: oid, + Path: node.Path, + ObjectOID: node.ObjectOID, + SizeState: node.SizeState, + SizeBytes: node.SizeBytes, Priority: PriorityExplicitRead, Reason: "explicit read", EnqueuedAt: time.Now(), diff --git a/internal/hydrator/hydrator_benchmark_test.go b/internal/hydrator/hydrator_benchmark_test.go new file mode 100644 index 0000000..855100e --- /dev/null +++ b/internal/hydrator/hydrator_benchmark_test.go @@ -0,0 +1,341 @@ +package hydrator + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/cloudflare/artifact-fs/internal/gitstore" + "github.com/cloudflare/artifact-fs/internal/model" +) + +func BenchmarkAsyncHydration(b *testing.B) { + const ( + objects = 4096 + hydratorWorkers = 8 + callerWorkers = 32 + ) + payload := []byte("small blob payload\n") + nodes := make([]model.BaseNode, objects) + for i := range nodes { + oid := fmt.Sprintf("%040x", i+1) + nodes[i] = model.BaseNode{ + RepoID: "repo", + Path: fmt.Sprintf("dir/file-%04d.txt", i), + ObjectOID: oid, + SizeState: "known", + SizeBytes: int64(len(payload)), + } + } + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: b.TempDir()} + fetcher := &fakeBlobFetcher{payload: payload} + h := New(fetcher) + h.Start(hydratorWorkers, cfg) + + jobs := make(chan model.BaseNode) + var wg sync.WaitGroup + for range callerWorkers { + wg.Add(1) + go func() { + defer wg.Done() + for n := range jobs { + if _, _, err := h.EnsureHydrated(context.Background(), cfg, n); err != nil { + b.Errorf("EnsureHydrated: %v", err) + return + } + } + }() + } + for _, n := range nodes { + jobs <- n + } + close(jobs) + wg.Wait() + h.Stop() + } +} + +func BenchmarkAsyncHydrationDuplicateReads(b *testing.B) { + const ( + objects = 512 + repeats = 8 + hydratorWorkers = 8 + callerWorkers = 32 + ) + payload := []byte("small blob payload\n") + nodes := make([]model.BaseNode, 0, objects*repeats) + for i := range objects { + oid 
:= fmt.Sprintf("%040x", i+1) + for r := range repeats { + nodes = append(nodes, model.BaseNode{ + RepoID: "repo", + Path: fmt.Sprintf("dir/file-%04d-%02d.txt", i, r), + ObjectOID: oid, + SizeState: "known", + SizeBytes: int64(len(payload)), + }) + } + } + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: b.TempDir()} + fetcher := &fakeBlobFetcher{payload: payload, fetchDelay: 100 * time.Microsecond} + h := New(fetcher) + h.Start(hydratorWorkers, cfg) + hydrateBenchmarkNodes(b, h, cfg, nodes, callerWorkers) + h.Stop() + } +} + +func BenchmarkQueuedDuplicatePrefetch(b *testing.B) { + const ( + objects = 512 + repeats = 8 + hydratorWorkers = 8 + ) + payload := []byte("small blob payload\n") + tasks := make([]model.HydrationTask, 0, objects*repeats) + for i := range objects { + oid := fmt.Sprintf("%040x", i+1) + for r := range repeats { + tasks = append(tasks, model.HydrationTask{ + RepoID: "repo", + Path: fmt.Sprintf("dir/file-%04d-%02d.txt", i, r), + ObjectOID: oid, + Priority: PriorityLikelyText, + Reason: "prefetch", + EnqueuedAt: time.Now(), + }) + } + } + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: b.TempDir()} + fetcher := &fakeBlobFetcher{payload: payload} + h := New(fetcher) + for _, task := range tasks { + h.Enqueue(task) + } + h.Start(hydratorWorkers, cfg) + waitBenchmarkHydratorIdle(b, h, cfg.ID) + h.Stop() + b.ReportMetric(float64(fetcher.Calls()), "fetches/op") + } +} + +func BenchmarkQueuedCachedPrefetch(b *testing.B) { + const ( + objects = 512 + hydratorWorkers = 8 + ) + payload := []byte("small blob payload\n") + tasks := make([]model.HydrationTask, 0, objects) + for i := range objects { + oid := fmt.Sprintf("%040x", i+1) + tasks = append(tasks, model.HydrationTask{ + RepoID: "repo", + Path: fmt.Sprintf("dir/file-%04d.txt", i), + ObjectOID: oid, + Priority: PriorityLikelyText, + Reason: "prefetch", + EnqueuedAt: time.Now(), + }) + } + 
+ b.ReportAllocs() + b.ResetTimer() + for range b.N { + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: b.TempDir()} + if err := os.MkdirAll(cfg.BlobCacheDir, 0o755); err != nil { + b.Fatal(err) + } + for _, task := range tasks { + if err := os.WriteFile(filepath.Join(cfg.BlobCacheDir, task.ObjectOID), payload, 0o644); err != nil { + b.Fatal(err) + } + } + fetcher := &fakeBlobFetcher{payload: []byte("new payload"), verifyOK: true} + h := New(fetcher) + for _, task := range tasks { + h.Enqueue(task) + } + h.Start(hydratorWorkers, cfg) + waitBenchmarkHydratorIdle(b, h, cfg.ID) + h.Stop() + b.ReportMetric(float64(fetcher.Calls()), "fetches/op") + b.ReportMetric(float64(fetcher.VerifyCalls()), "verifies/op") + } +} + +func BenchmarkQueuedPrefetchWorkerRamp(b *testing.B) { + const ( + objects = 8 + hydratorWorkers = 8 + ) + payload := []byte("small blob payload\n") + tasks := make([]model.HydrationTask, 0, objects) + for i := range objects { + tasks = append(tasks, model.HydrationTask{ + RepoID: "repo", + Path: fmt.Sprintf("dir/file-%04d.txt", i), + ObjectOID: fmt.Sprintf("%040x", i+1), + Priority: PriorityLikelyText, + Reason: "prefetch", + EnqueuedAt: time.Now(), + }) + } + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: b.TempDir()} + fetcher := &fakeBlobFetcher{payload: payload, fetchDelay: 10 * time.Millisecond} + h := New(fetcher) + for _, task := range tasks { + h.Enqueue(task) + } + h.Start(hydratorWorkers, cfg) + waitBenchmarkHydratorIdle(b, h, cfg.ID) + h.Stop() + b.ReportMetric(float64(fetcher.Calls()), "fetches/op") + } +} + +func BenchmarkAsyncHydrationGitStore(b *testing.B) { + const ( + objects = 2048 + hydratorWorkers = 8 + callerWorkers = 32 + ) + + workDir, gitDir := createBenchmarkGitRepo(b, objects) + cfg := model.RepoConfig{ID: "repo", GitDir: gitDir} + git := gitstore.New(nil) + b.Cleanup(func() { + git.Close() + if err := os.RemoveAll(workDir); err != nil { + b.Errorf("remove benchmark 
repo: %v", err) + } + }) + git.SetBatchPoolSize(hydratorWorkers) + + head, _, err := git.ResolveHEAD(context.Background(), cfg) + if err != nil { + b.Fatalf("ResolveHEAD: %v", err) + } + nodes, err := git.BuildTreeIndex(context.Background(), cfg, head) + if err != nil { + b.Fatalf("BuildTreeIndex: %v", err) + } + targets := make([]model.BaseNode, 0, len(nodes)) + for _, n := range nodes { + if n.Type == "file" && n.ObjectOID != "" { + targets = append(targets, n) + } + } + if len(targets) != objects { + b.Fatalf("targets = %d, want %d", len(targets), objects) + } + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + cfg.BlobCacheDir = b.TempDir() + h := New(git) + h.Start(hydratorWorkers, cfg) + hydrateBenchmarkNodes(b, h, cfg, targets, callerWorkers) + h.Stop() + } +} + +func hydrateBenchmarkNodes(b *testing.B, h *Service, cfg model.RepoConfig, nodes []model.BaseNode, callerWorkers int) { + b.Helper() + jobs := make(chan model.BaseNode) + var wg sync.WaitGroup + for range callerWorkers { + wg.Add(1) + go func() { + defer wg.Done() + for n := range jobs { + if _, _, err := h.EnsureHydrated(context.Background(), cfg, n); err != nil { + b.Errorf("EnsureHydrated: %v", err) + return + } + } + }() + } + for _, n := range nodes { + jobs <- n + } + close(jobs) + wg.Wait() +} + +func waitBenchmarkHydratorIdle(b *testing.B, h *Service, repoID model.RepoID) { + b.Helper() + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + h.mu.Lock() + queued := 0 + for _, item := range h.pq { + if item.task.RepoID == repoID { + queued++ + } + } + idle := queued == 0 && len(h.active) == 0 + h.mu.Unlock() + if idle { + return + } + time.Sleep(1 * time.Millisecond) + } + b.Fatalf("hydrator did not become idle") +} + +func createBenchmarkGitRepo(b *testing.B, objects int) (workDir string, gitDir string) { + b.Helper() + workDir, err := os.MkdirTemp("", "artifact-fs-hydrator-bench-") + if err != nil { + b.Fatal(err) + } + runBenchmarkGit(b, workDir, "init") 
+ runBenchmarkGit(b, workDir, "config", "user.name", "Hydrator Bench") + runBenchmarkGit(b, workDir, "config", "user.email", "hydrator-bench@example.com") + for i := range objects { + dir := filepath.Join(workDir, fmt.Sprintf("dir-%02d", i%16)) + if err := os.MkdirAll(dir, 0o755); err != nil { + b.Fatal(err) + } + path := filepath.Join(dir, fmt.Sprintf("file-%04d.txt", i)) + data := []byte(fmt.Sprintf("blob payload %04d\n", i)) + if err := os.WriteFile(path, data, 0o644); err != nil { + b.Fatal(err) + } + } + runBenchmarkGit(b, workDir, "add", ".") + runBenchmarkGit(b, workDir, "commit", "-m", "add benchmark blobs") + return workDir, filepath.Join(workDir, ".git") +} + +func runBenchmarkGit(b *testing.B, dir string, args ...string) { + b.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + b.Fatalf("git %v: %v\n%s", args, err, out) + } +} diff --git a/internal/hydrator/hydrator_test.go b/internal/hydrator/hydrator_test.go index 31cbd99..26d27ef 100644 --- a/internal/hydrator/hydrator_test.go +++ b/internal/hydrator/hydrator_test.go @@ -202,6 +202,143 @@ func TestEnsureHydratedVerifiesUnknownCacheHitOnce(t *testing.T) { } } +func TestEnsureHydratedJoinsActiveFetch(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + payload := []byte("content") + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: tmp} + node := model.BaseNode{RepoID: cfg.ID, Path: "file.txt", ObjectOID: "blob", SizeState: "known", SizeBytes: int64(len(payload))} + releaseFetch := make(chan struct{}) + fetchStarted := make(chan struct{}) + fetcher := &fakeBlobFetcher{payload: payload, fetchStarted: fetchStarted, fetchWait: releaseFetch} + h := New(fetcher) + h.Start(1, cfg) + defer h.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + const readers = 8 + errCh := make(chan error, readers+1) + go func() { + _, _, err := h.EnsureHydrated(ctx, cfg, node) + errCh <- err + }() + 
<-fetchStarted + for range readers { + go func() { + _, _, err := h.EnsureHydrated(ctx, cfg, node) + errCh <- err + }() + } + runtime.Gosched() + close(releaseFetch) + for i := 0; i < readers+1; i++ { + if err := <-errCh; err != nil { + t.Fatalf("EnsureHydrated: %v", err) + } + } + if fetcher.Calls() != 1 { + t.Fatalf("fetch calls = %d, want 1", fetcher.Calls()) + } +} + +func TestEnqueueDedupesAndUpgradesPriority(t *testing.T) { + t.Parallel() + h := New(&fakeBlobFetcher{}) + low := model.HydrationTask{RepoID: "repo", Path: "image.png", ObjectOID: "blob", Priority: PriorityBinary, Reason: "prefetch", EnqueuedAt: time.Now()} + high := model.HydrationTask{RepoID: "repo", Path: "README.md", ObjectOID: "blob", Priority: PriorityBootstrap, Reason: "prefetch", EnqueuedAt: time.Now().Add(time.Second)} + + h.Enqueue(low) + h.Enqueue(low) + if got := h.QueueDepth("repo"); got != 1 { + t.Fatalf("QueueDepth after duplicate enqueue = %d, want 1", got) + } + h.EnqueueBatch([]model.HydrationTask{low, high}) + if got := h.QueueDepth("repo"); got != 1 { + t.Fatalf("QueueDepth after priority upgrade = %d, want 1", got) + } + h.mu.Lock() + got := h.pq[0].task + h.mu.Unlock() + if got.Priority != PriorityBootstrap || got.Path != "README.md" { + t.Fatalf("queued task = %+v, want upgraded README priority", got) + } +} + +func TestQueuedHydrationUsesValidCacheHit(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + payload := []byte("content") + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: tmp} + cachePath := filepath.Join(tmp, "blob") + if err := os.WriteFile(cachePath, payload, 0o644); err != nil { + t.Fatal(err) + } + fetcher := &fakeBlobFetcher{payload: []byte("new-data"), verifyOK: true} + h := New(fetcher) + hydrated := make(chan struct{}) + var hydratedOnce sync.Once + h.SetOnHydrated(func(model.RepoID, string, int64) { + hydratedOnce.Do(func() { close(hydrated) }) + }) + h.Enqueue(model.HydrationTask{ + RepoID: cfg.ID, + Path: "file.txt", + ObjectOID: "blob", + SizeState: 
"known", + SizeBytes: int64(len(payload)), + Priority: PriorityBootstrap, + Reason: "prefetch", + EnqueuedAt: time.Now(), + }) + h.Start(1, cfg) + defer h.Stop() + + select { + case <-hydrated: + case <-time.After(2 * time.Second): + t.Fatal("queued hydration did not complete") + } + if fetcher.Calls() != 0 { + t.Fatalf("fetch calls = %d, want 0", fetcher.Calls()) + } + if fetcher.VerifyCalls() != 1 { + t.Fatalf("verify calls = %d, want 1", fetcher.VerifyCalls()) + } + data, err := os.ReadFile(cachePath) + if err != nil { + t.Fatal(err) + } + if string(data) != string(payload) { + t.Fatalf("cache contents = %q, want %q", data, payload) + } +} + +func TestQueuedHydrationWakesWorkersForBacklog(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: tmp} + releaseFetch := make(chan struct{}) + fetcher := &fakeBlobFetcher{payload: []byte("content"), fetchWait: releaseFetch} + h := New(fetcher) + for i := range 4 { + h.Enqueue(model.HydrationTask{ + RepoID: cfg.ID, + Path: filepath.Join("dir", "file.txt"), + ObjectOID: string(rune('a' + i)), + Priority: PriorityBootstrap, + Reason: "prefetch", + EnqueuedAt: time.Now(), + }) + } + h.Start(4, cfg) + defer h.Stop() + + waitForFetchCalls(t, fetcher, 4) + close(releaseFetch) + waitForHydratorIdle(t, h, cfg.ID) +} + func TestEnsureHydratedVerificationIgnoresLeaderTimeout(t *testing.T) { t.Parallel() tmp := t.TempDir() @@ -362,6 +499,39 @@ func TestReadBlobSkipsVerificationForOversizedCache(t *testing.T) { } } +func waitForFetchCalls(t *testing.T, fetcher *fakeBlobFetcher, want int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if fetcher.Calls() >= want { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("fetch calls = %d, want at least %d", fetcher.Calls(), want) +} + +func waitForHydratorIdle(t *testing.T, h *Service, repoID model.RepoID) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for 
time.Now().Before(deadline) { + h.mu.Lock() + queued := 0 + for _, item := range h.pq { + if item.task.RepoID == repoID { + queued++ + } + } + idle := queued == 0 && len(h.active) == 0 + h.mu.Unlock() + if idle { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("hydrator did not become idle") +} + type fakeBlobFetcher struct { mu sync.Mutex calls int @@ -371,8 +541,12 @@ type fakeBlobFetcher struct { readBlobErr error verifyOK bool verifyErr error + fetchStarted chan struct{} + fetchWait <-chan struct{} + fetchDelay time.Duration verifyStarted chan struct{} verifyWait <-chan struct{} + fetchOnce sync.Once verifyOnce sync.Once } @@ -380,6 +554,15 @@ func (f *fakeBlobFetcher) BlobToCache(_ context.Context, _ model.RepoConfig, _ s f.mu.Lock() f.calls++ f.mu.Unlock() + if f.fetchStarted != nil { + f.fetchOnce.Do(func() { close(f.fetchStarted) }) + } + if f.fetchWait != nil { + <-f.fetchWait + } + if f.fetchDelay > 0 { + time.Sleep(f.fetchDelay) + } if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil { return 0, err } diff --git a/internal/model/types.go b/internal/model/types.go index 7897453..034f92b 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -28,6 +28,10 @@ type RepoConfig struct { MetaDBPath string OverlayDBPath string Enabled bool + PreparedGitDir bool + FetchRef string + PrepareState string + PrepareError string } type RepoRuntimeState struct { @@ -44,8 +48,15 @@ type RepoRuntimeState struct { HydratedBlobBytes int64 DirtyOverlay bool State string + PrepareError string } +const ( + PrepareStatePreparing = "preparing" + PrepareStateReady = "ready" + PrepareStateFailed = "failed" +) + // BaseNode represents a tracked entry from the git tree. Inode IDs are assigned // at runtime by the FUSE layer (monotonic allocation, like tigrisfs). 
type BaseNode struct { @@ -101,6 +112,8 @@ type HydrationTask struct { RepoID RepoID Path string ObjectOID string + SizeState string + SizeBytes int64 Priority int Reason string EnqueuedAt time.Time @@ -143,7 +156,9 @@ type Registry interface { type GitStore interface { CloneBlobless(ctx context.Context, cfg RepoConfig) error + CloneBloblessNonInteractive(ctx context.Context, cfg RepoConfig) error Fetch(ctx context.Context, repo RepoConfig) error + FetchRefNonInteractive(ctx context.Context, repo RepoConfig, ref string) error ResolveHEAD(ctx context.Context, repo RepoConfig) (oid string, ref string, err error) BuildTreeIndex(ctx context.Context, repo RepoConfig, headOID string) ([]BaseNode, error) BlobToCache(ctx context.Context, repo RepoConfig, objectOID string, dstPath string) (size int64, err error) @@ -151,6 +166,8 @@ type GitStore interface { ComputeAheadBehind(ctx context.Context, repo RepoConfig) (ahead int, behind int, diverged bool, err error) CommitTimestamp(ctx context.Context, repo RepoConfig, oid string) (int64, error) ReadTreeHEAD(ctx context.Context, repo RepoConfig) error + PrepareFetchedBranch(ctx context.Context, repo RepoConfig, ref string) error + ValidatePreparedGitDir(ctx context.Context, repo RepoConfig) error } type SnapshotStore interface { @@ -177,6 +194,7 @@ type OverlayStore interface { type Hydrator interface { Enqueue(task HydrationTask) + EnqueueBatch(tasks []HydrationTask) EnsureHydrated(ctx context.Context, repo RepoConfig, node BaseNode) (cachePath string, size int64, err error) ReadBlob(ctx context.Context, repo RepoConfig, node BaseNode, maxBytes int64) ([]byte, error) QueueDepth(repoID RepoID) int diff --git a/internal/registry/registry.go b/internal/registry/registry.go index e07e766..6ff1b39 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -11,12 +11,15 @@ import ( "github.com/cloudflare/artifact-fs/internal/model" ) +var ErrRepoChanged = errors.New("repo config changed") + var migrations = 
[]string{ `CREATE TABLE IF NOT EXISTS repos ( repo_id TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, mount_root TEXT NOT NULL, mount_path TEXT NOT NULL, + remote_url TEXT NOT NULL DEFAULT '', remote_url_redacted TEXT NOT NULL, remote_url_secret_ref TEXT, branch TEXT NOT NULL, @@ -27,6 +30,10 @@ var migrations = []string{ meta_db_path TEXT NOT NULL, overlay_db_path TEXT NOT NULL, enabled INTEGER NOT NULL DEFAULT 1, + prepared_gitdir INTEGER NOT NULL DEFAULT 0, + fetch_ref TEXT NOT NULL DEFAULT '', + prepare_state TEXT NOT NULL DEFAULT '', + prepare_error TEXT NOT NULL DEFAULT '', created_at_ns INTEGER NOT NULL, updated_at_ns INTEGER NOT NULL );`, @@ -44,6 +51,10 @@ func New(ctx context.Context, dbPath string) (*Store, error) { if err := meta.ExecMigrations(ctx, db, migrations); err != nil { return nil, err } + if err := ensureRepoColumns(ctx, db); err != nil { + db.Close() + return nil, err + } return &Store{db: db}, nil } @@ -52,12 +63,13 @@ func (s *Store) Close() error { return s.db.Close() } func (s *Store) AddRepo(ctx context.Context, cfg model.RepoConfig) error { now := time.Now().UnixNano() _, err := s.db.ExecContext(ctx, ` - INSERT INTO repos (repo_id, name, mount_root, mount_path, remote_url_redacted, branch, refresh_interval_seconds, git_dir, overlay_dir, blob_cache_dir, meta_db_path, overlay_db_path, enabled, created_at_ns, updated_at_ns) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO repos (repo_id, name, mount_root, mount_path, remote_url, remote_url_redacted, branch, refresh_interval_seconds, git_dir, overlay_dir, blob_cache_dir, meta_db_path, overlay_db_path, enabled, prepared_gitdir, fetch_ref, prepare_state, prepare_error, created_at_ns, updated_at_ns) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
ON CONFLICT(repo_id) DO UPDATE SET name=excluded.name, mount_root=excluded.mount_root, mount_path=excluded.mount_path, + remote_url=excluded.remote_url, remote_url_redacted=excluded.remote_url_redacted, branch=excluded.branch, refresh_interval_seconds=excluded.refresh_interval_seconds, @@ -67,23 +79,65 @@ func (s *Store) AddRepo(ctx context.Context, cfg model.RepoConfig) error { meta_db_path=excluded.meta_db_path, overlay_db_path=excluded.overlay_db_path, enabled=excluded.enabled, + prepared_gitdir=excluded.prepared_gitdir, + fetch_ref=excluded.fetch_ref, + prepare_state=excluded.prepare_state, + prepare_error=excluded.prepare_error, updated_at_ns=excluded.updated_at_ns - `, string(cfg.ID), cfg.Name, cfg.MountRoot, cfg.MountPath, cfg.RemoteURLRedacted, cfg.Branch, int64(cfg.RefreshInterval.Seconds()), cfg.GitDir, cfg.OverlayDir, cfg.BlobCacheDir, cfg.MetaDBPath, cfg.OverlayDBPath, boolToInt(cfg.Enabled), now, now) + `, string(cfg.ID), cfg.Name, cfg.MountRoot, cfg.MountPath, cfg.RemoteURL, cfg.RemoteURLRedacted, cfg.Branch, int64(cfg.RefreshInterval.Seconds()), cfg.GitDir, cfg.OverlayDir, cfg.BlobCacheDir, cfg.MetaDBPath, cfg.OverlayDBPath, boolToInt(cfg.Enabled), boolToInt(cfg.PreparedGitDir), cfg.FetchRef, cfg.PrepareState, cfg.PrepareError, now, now) + return err +} + +func (s *Store) UpdatePrepareState(ctx context.Context, repoID model.RepoID, state string, prepareErr string) error { + _, err := s.db.ExecContext(ctx, ` + UPDATE repos + SET prepare_state=?, prepare_error=?, updated_at_ns=? + WHERE repo_id=? + `, state, prepareErr, time.Now().UnixNano(), string(repoID)) return err } +func (s *Store) UpdatePrepareStateForConfig(ctx context.Context, cfg model.RepoConfig, state string, prepareErr string) error { + res, err := s.db.ExecContext(ctx, ` + UPDATE repos + SET prepare_state=?, prepare_error=?, updated_at_ns=? + WHERE repo_id=? + AND branch=? + AND remote_url=? + AND prepared_gitdir=? + AND fetch_ref=? + AND git_dir=? + AND overlay_dir=? 
+ AND blob_cache_dir=? + AND meta_db_path=? + AND overlay_db_path=? + AND mount_path=? + `, state, prepareErr, time.Now().UnixNano(), string(cfg.ID), cfg.Branch, cfg.RemoteURL, boolToInt(cfg.PreparedGitDir), cfg.FetchRef, cfg.GitDir, cfg.OverlayDir, cfg.BlobCacheDir, cfg.MetaDBPath, cfg.OverlayDBPath, cfg.MountPath) + if err != nil { + return err + } + rows, err := res.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return ErrRepoChanged + } + return nil +} + func (s *Store) RemoveRepo(ctx context.Context, name string) error { _, err := s.db.ExecContext(ctx, `DELETE FROM repos WHERE name=?`, name) return err } func (s *Store) GetRepo(ctx context.Context, name string) (model.RepoConfig, error) { - row := s.db.QueryRowContext(ctx, `SELECT repo_id, name, mount_root, mount_path, remote_url_redacted, branch, refresh_interval_seconds, git_dir, overlay_dir, blob_cache_dir, meta_db_path, overlay_db_path, enabled FROM repos WHERE name=?`, name) + row := s.db.QueryRowContext(ctx, `SELECT repo_id, name, mount_root, mount_path, remote_url, remote_url_redacted, branch, refresh_interval_seconds, git_dir, overlay_dir, blob_cache_dir, meta_db_path, overlay_db_path, enabled, prepared_gitdir, fetch_ref, prepare_state, prepare_error FROM repos WHERE name=?`, name) return scanRepo(row) } func (s *Store) ListRepos(ctx context.Context) ([]model.RepoConfig, error) { - rows, err := s.db.QueryContext(ctx, `SELECT repo_id, name, mount_root, mount_path, remote_url_redacted, branch, refresh_interval_seconds, git_dir, overlay_dir, blob_cache_dir, meta_db_path, overlay_db_path, enabled FROM repos ORDER BY name`) + rows, err := s.db.QueryContext(ctx, `SELECT repo_id, name, mount_root, mount_path, remote_url, remote_url_redacted, branch, refresh_interval_seconds, git_dir, overlay_dir, blob_cache_dir, meta_db_path, overlay_db_path, enabled, prepared_gitdir, fetch_ref, prepare_state, prepare_error FROM repos ORDER BY name`) if err != nil { return nil, err } @@ -107,7 +161,8 @@ 
func scanRepo(s scanner) (model.RepoConfig, error) { var cfg model.RepoConfig var refresh int64 var enabled int - if err := s.Scan(&cfg.ID, &cfg.Name, &cfg.MountRoot, &cfg.MountPath, &cfg.RemoteURLRedacted, &cfg.Branch, &refresh, &cfg.GitDir, &cfg.OverlayDir, &cfg.BlobCacheDir, &cfg.MetaDBPath, &cfg.OverlayDBPath, &enabled); err != nil { + var preparedGitDir int + if err := s.Scan(&cfg.ID, &cfg.Name, &cfg.MountRoot, &cfg.MountPath, &cfg.RemoteURL, &cfg.RemoteURLRedacted, &cfg.Branch, &refresh, &cfg.GitDir, &cfg.OverlayDir, &cfg.BlobCacheDir, &cfg.MetaDBPath, &cfg.OverlayDBPath, &enabled, &preparedGitDir, &cfg.FetchRef, &cfg.PrepareState, &cfg.PrepareError); err != nil { if errors.Is(err, sql.ErrNoRows) { return cfg, fmt.Errorf("repo not found") } @@ -115,6 +170,7 @@ func scanRepo(s scanner) (model.RepoConfig, error) { } cfg.RefreshInterval = time.Duration(refresh) * time.Second cfg.Enabled = enabled == 1 + cfg.PreparedGitDir = preparedGitDir == 1 return cfg, nil } @@ -124,3 +180,43 @@ func boolToInt(v bool) int { } return 0 } + +func ensureRepoColumns(ctx context.Context, db *sql.DB) error { + rows, err := db.QueryContext(ctx, `PRAGMA table_info(repos)`) + if err != nil { + return err + } + defer rows.Close() + cols := map[string]bool{} + for rows.Next() { + var cid int + var name string + var typ string + var notNull int + var defaultValue any + var pk int + if err := rows.Scan(&cid, &name, &typ, &notNull, &defaultValue, &pk); err != nil { + return err + } + cols[name] = true + } + if err := rows.Err(); err != nil { + return err + } + add := map[string]string{ + "remote_url": `TEXT NOT NULL DEFAULT ''`, + "prepared_gitdir": `INTEGER NOT NULL DEFAULT 0`, + "fetch_ref": `TEXT NOT NULL DEFAULT ''`, + "prepare_state": `TEXT NOT NULL DEFAULT ''`, + "prepare_error": `TEXT NOT NULL DEFAULT ''`, + } + for name, ddl := range add { + if cols[name] { + continue + } + if _, err := db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE repos ADD COLUMN %s %s`, name, ddl)); err != nil { +
return err + } + } + return nil +} diff --git a/internal/registry/registry_test.go b/internal/registry/registry_test.go new file mode 100644 index 0000000..5295361 --- /dev/null +++ b/internal/registry/registry_test.go @@ -0,0 +1,62 @@ +package registry + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/cloudflare/artifact-fs/internal/model" +) + +func TestRepoPrepareFieldsRoundTrip(t *testing.T) { + ctx := context.Background() + store, err := New(ctx, filepath.Join(t.TempDir(), "repos.sqlite")) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + cfg := model.RepoConfig{ + ID: "repo", + Name: "repo", + MountRoot: "/mnt", + MountPath: "/mnt/repo", + RemoteURL: "https://github.com/example/repo.git", + RemoteURLRedacted: "https://github.com/example/repo.git", + Branch: "master", + RefreshInterval: time.Minute, + GitDir: "/git/repo", + OverlayDir: "/overlay/repo", + BlobCacheDir: "/cache/repo", + MetaDBPath: "/meta/repo.sqlite", + OverlayDBPath: "/overlay/repo/meta.sqlite", + Enabled: true, + PreparedGitDir: true, + FetchRef: "master", + PrepareState: model.PrepareStatePreparing, + } + if err := store.AddRepo(ctx, cfg); err != nil { + t.Fatal(err) + } + if err := store.UpdatePrepareState(ctx, cfg.ID, model.PrepareStateFailed, "clone failed"); err != nil { + t.Fatal(err) + } + + got, err := store.GetRepo(ctx, "repo") + if err != nil { + t.Fatal(err) + } + if !got.PreparedGitDir { + t.Fatal("PreparedGitDir = false, want true") + } + if got.FetchRef != "master" { + t.Fatalf("FetchRef = %q, want master", got.FetchRef) + } + if got.PrepareState != model.PrepareStateFailed { + t.Fatalf("PrepareState = %q, want failed", got.PrepareState) + } + if got.PrepareError != "clone failed" { + t.Fatalf("PrepareError = %q, want clone failed", got.PrepareError) + } +} diff --git a/internal/snapshot/store.go b/internal/snapshot/store.go index ea3435d..123d42b 100644 --- a/internal/snapshot/store.go +++ b/internal/snapshot/store.go @@ -28,6 +28,7 @@ 
var migrations = []string{ PRIMARY KEY (generation, path) );`, `CREATE INDEX IF NOT EXISTS idx_base_nodes_gen_parent ON base_nodes(generation, parent_path);`, + `CREATE INDEX IF NOT EXISTS idx_base_nodes_gen_oid ON base_nodes(generation, object_oid);`, `DROP TABLE IF EXISTS learned_path_stats;`, `DROP TABLE IF EXISTS blob_cache_index;`, } @@ -164,7 +165,29 @@ func (s *Store) ListChildren(generation int64, parentPath string) ([]model.BaseN // the given OID in the current generation so stat() returns the correct size // without waiting for a full re-index. func (s *Store) UpdateSize(generation int64, objectOID string, size int64) { - s.db.Exec(`UPDATE base_nodes SET size_bytes=?, size_state='known' WHERE generation=? AND object_oid=?`, size, generation, objectOID) + _ = s.UpdateSizes(context.Background(), generation, map[string]int64{objectOID: size}) +} + +func (s *Store) UpdateSizes(ctx context.Context, generation int64, sizes map[string]int64) error { + if len(sizes) == 0 { + return nil + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + stmt, err := tx.PrepareContext(ctx, `UPDATE base_nodes SET size_bytes=?, size_state='known' WHERE generation=? 
AND object_oid=?`) + if err != nil { + return err + } + defer stmt.Close() + for oid, size := range sizes { + if _, err := stmt.ExecContext(ctx, size, generation, oid); err != nil { + return err + } + } + return tx.Commit() } func (s *Store) nextGenerationTx(ctx context.Context, tx *sql.Tx) (int64, error) { diff --git a/internal/snapshot/store_test.go b/internal/snapshot/store_test.go index 8137eb6..56c30e0 100644 --- a/internal/snapshot/store_test.go +++ b/internal/snapshot/store_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "path/filepath" + "strings" "testing" "github.com/cloudflare/artifact-fs/internal/meta" @@ -147,6 +148,49 @@ func TestCurrentGeneration(t *testing.T) { } } +func TestUpdateSizeUsesGenerationObjectOIDIndex(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + nodes := []model.BaseNode{ + {RepoID: "r", Path: ".", Type: "dir", Mode: 0o755, SizeState: "known"}, + {RepoID: "r", Path: "a.txt", Type: "file", Mode: 0o644, ObjectOID: "shared", SizeState: "unknown"}, + {RepoID: "r", Path: "b.txt", Type: "file", Mode: 0o644, ObjectOID: "other", SizeState: "unknown"}, + } + gen, err := s.PublishGeneration(ctx, "h1", "main", nodes) + if err != nil { + t.Fatal(err) + } + + var id, parent, notUsed int + var detail string + err = s.db.QueryRowContext(ctx, `EXPLAIN QUERY PLAN UPDATE base_nodes SET size_bytes=?, size_state='known' WHERE generation=? 
AND object_oid=?`, 42, gen, "shared").Scan(&id, &parent, &notUsed, &detail) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(detail, "idx_base_nodes_gen_oid") { + t.Fatalf("UpdateSize plan = %q, want idx_base_nodes_gen_oid", detail) + } + + if err := s.UpdateSizes(ctx, gen, map[string]int64{"shared": 42, "other": 7}); err != nil { + t.Fatal(err) + } + n, ok := s.GetNode(gen, "a.txt") + if !ok { + t.Fatal("a.txt not found") + } + if n.SizeState != "known" || n.SizeBytes != 42 { + t.Fatalf("a.txt size = %s/%d, want known/42", n.SizeState, n.SizeBytes) + } + n, ok = s.GetNode(gen, "b.txt") + if !ok { + t.Fatal("b.txt not found") + } + if n.SizeState != "known" || n.SizeBytes != 7 { + t.Fatalf("b.txt size = %s/%d, want known/7", n.SizeState, n.SizeBytes) + } +} + func TestNewDropsLegacyTables(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "snap.sqlite")