diff --git a/README.md b/README.md index d5eac52..a560e67 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,39 @@ The CLI uses Notion's remote MCP server with OAuth authentication. On first run, **Note:** Access tokens expire after 1 hour. The CLI automatically refreshes tokens when they expire or are about to expire, so you typically don't need to think about this. Use `notion-cli auth refresh` to manually refresh if needed. +### Profiles + +Every command accepts `--profile ` (or `NOTION_CLI_PROFILE`) to scope the OAuth token and official API config to a specific Notion account, so you can keep separate logins for `work`, `home`, etc. + +```bash +# Log in to a named profile +notion-cli auth login --profile work + +# Use the profile for a single command +notion-cli page list --profile work + +# Pin a profile for the shell session +export NOTION_CLI_PROFILE=work + +# Make a profile the default for future invocations +notion-cli auth use work +``` + +Profile resolution, highest priority first: + +1. `--profile ` flag +2. `NOTION_CLI_PROFILE` environment variable +3. Active profile from `notion-cli auth use ` +4. Implicit default profile + +The default profile keeps using the existing OAuth token path, so existing single-account installs need no migration. `notion-cli auth use ` stores the active profile in the cross-profile config directory. + +Profile names must start and end with a lowercase ASCII letter or number. They may contain lowercase letters, numbers, at signs, dots, underscores, and hyphens. + +Named profiles store their credentials under `~/.config/notion-cli/profiles//{token,config}.json`. + +`notion-cli auth status` prints the selected profile and token path, and `notion-cli auth list` shows all known profiles with OAuth and API-token status. + ## Environment Variables | Variable | Description | @@ -182,6 +215,7 @@ The CLI uses Notion's remote MCP server with OAuth authentication. On first run, | `NOTION_API_TOKEN` | Official Notion API token used for upload fallback and verification | | `NOTION_API_BASE_URL` | Override the official Notion API base URL | | `NOTION_API_NOTION_VERSION` | Override the official Notion API version | +| `NOTION_CLI_PROFILE` | Default profile when `--profile` is not passed | ## How It Works diff --git a/cmd/auth.go b/cmd/auth.go index 963f11f..285b1bd 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -9,6 +9,7 @@ import ( "os" "regexp" "strings" + "time" "github.com/fatih/color" "github.com/lox/notion-cli/internal/cli" @@ -22,6 +23,8 @@ type AuthCmd struct { Login AuthLoginCmd `cmd:"" help:"Authenticate with Notion via OAuth"` Refresh AuthRefreshCmd `cmd:"" help:"Refresh the access token"` Status AuthStatusCmd `cmd:"" default:"withargs" help:"Show authentication status"` + List AuthListCmd `cmd:"" help:"List profiles and authentication state"` + Use AuthUseCmd `cmd:"" help:"Set the active profile"` Logout AuthLogoutCmd `cmd:"" help:"Clear stored credentials"` API AuthAPICmd `cmd:"" name:"api" help:"Official API token commands"` } @@ -34,10 +37,21 @@ var notionAPITokenPattern = regexp.MustCompile(`^ntn_[A-Za-z0-9]{20,}$`) const officialAPIIntegrationsURL = "https://www.notion.so/profile/integrations/internal" +type authProfileStatus struct { + Profile string `json:"profile"` + Active bool `json:"active"` + HasOAuthToken bool `json:"has_oauth_token"` + OAuthStatus string `json:"oauth_status"` + OAuthExpiresAt *time.Time `json:"oauth_expires_at,omitempty"` + HasAPIToken bool `json:"has_api_token"` + TokenPath string `json:"token_path"` + ConfigPath string `json:"config_path"` +} + type AuthLoginCmd struct{} func (c *AuthLoginCmd) Run(ctx *Context) error { - tokenStore, err := mcp.NewFileTokenStore() + tokenStore, err := mcp.NewFileTokenStore(ctx.Profile) if err != nil { output.PrintError(err) return err @@ -55,7 +69,7 @@ func (c *AuthLoginCmd) Run(ctx *Context) error { type AuthRefreshCmd struct{} func (c *AuthRefreshCmd) Run(ctx *Context) error { - tokenStore, err := mcp.NewFileTokenStore() + tokenStore, err := mcp.NewFileTokenStore(ctx.Profile) if err != nil { output.PrintError(err) return err @@ -95,63 +109,147 @@ type AuthStatusCmd struct { func (c *AuthStatusCmd) Run(ctx *Context) error { ctx.JSON = c.JSON - tokenStore, err := mcp.NewFileTokenStore() + status, err := inspectProfileStatus(ctx.Profile) if err != nil { output.PrintError(err) return err } - token, err := tokenStore.GetToken(context.Background()) - if err != nil { - if err == mcp.ErrNoToken { - fmt.Println("Not authenticated. Run 'notion-cli auth login' to authenticate.") - return nil + if ctx.JSON { + payload := map[string]any{ + "authenticated": status.OAuthStatus == "valid", + "profile": status.Profile, + "active": status.Active, + "has_oauth_token": status.HasOAuthToken, + "oauth_status": status.OAuthStatus, + "has_api_token": status.HasAPIToken, + "token_path": status.TokenPath, + "config_path": status.ConfigPath, + } + if status.OAuthExpiresAt != nil { + payload["oauth_expires_at"] = status.OAuthExpiresAt } + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + return enc.Encode(payload) + } + + labelStyle := color.New(color.Faint) + switch status.OAuthStatus { + case "valid": + output.PrintSuccess("Authenticated") + case "login_required": + output.PrintWarning("Login required") + default: + output.PrintWarning("Not authenticated") + } + fmt.Println() + + _, _ = labelStyle.Print("Profile: ") + fmt.Println(status.Profile) + _, _ = labelStyle.Print("Token path: ") + fmt.Println(status.TokenPath) + + if status.OAuthExpiresAt != nil { + _, _ = labelStyle.Print("Expires: ") + fmt.Println(status.OAuthExpiresAt.Format("2 Jan 2006 15:04")) + } + if status.OAuthStatus == "missing" || status.OAuthStatus == "login_required" { + _, _ = fmt.Fprintln(os.Stdout, "Run 'notion-cli auth login' to authenticate this profile.") + } + + return nil +} + +type AuthListCmd struct { + JSON bool `help:"Output as JSON" short:"j"` +} + +func (c *AuthListCmd) Run(ctx *Context) error { + profiles, err := config.ListProfiles() + if err != nil { output.PrintError(err) return err } - hasValidToken := token.AccessToken != "" && !token.IsExpired() + rows := make([]authProfileStatus, 0, len(profiles)) + for _, profile := range profiles { + row, err := inspectProfileStatus(profile) + if err != nil { + output.PrintError(err) + return err + } + rows = append(rows, row) + } - if ctx.JSON { + if c.JSON { enc := json.NewEncoder(os.Stdout) enc.SetIndent("", " ") - return enc.Encode(map[string]any{ - "authenticated": hasValidToken, - "token_type": token.TokenType, - "has_token": token.AccessToken != "", - "expires_at": token.ExpiresAt, - "config_path": tokenStore.Path(), - }) + return enc.Encode(rows) } labelStyle := color.New(color.Faint) - - if hasValidToken { - output.PrintSuccess("Authenticated") - } else { - output.PrintWarning("Token expired or not set") + for i, row := range rows { + header := row.Profile + if row.Active { + header += " (active)" + } + fmt.Println(header) + _, _ = labelStyle.Print(" OAuth: ") + fmt.Println(row.OAuthStatus) + if row.OAuthExpiresAt != nil { + _, _ = labelStyle.Print(" Expires: ") + fmt.Println(row.OAuthExpiresAt.Format("2 Jan 2006 15:04")) + } + _, _ = labelStyle.Print(" API token: ") + if row.HasAPIToken { + fmt.Println("configured") + } else { + fmt.Println("missing") + } + _, _ = labelStyle.Print(" Token path: ") + fmt.Println(row.TokenPath) + _, _ = labelStyle.Print(" Config path: ") + fmt.Println(row.ConfigPath) + if i < len(rows)-1 { + fmt.Println() + } } - fmt.Println() - _, _ = labelStyle.Print("Config path: ") - fmt.Println(tokenStore.Path()) + return nil +} - _, _ = labelStyle.Print("Token type: ") - fmt.Println(token.TokenType) +type AuthUseCmd struct { + Profile string `arg:"" help:"Profile name to make active"` +} - if !token.ExpiresAt.IsZero() { - _, _ = labelStyle.Print("Expires: ") - fmt.Println(token.ExpiresAt.Format("2 Jan 2006 15:04")) +func (c *AuthUseCmd) Run(ctx *Context) error { + if err := config.SetActiveProfile(c.Profile); err != nil { + output.PrintError(err) + return err } + status, err := inspectProfileStatus(c.Profile) + if err != nil { + output.PrintError(err) + return err + } + + output.PrintSuccess("Active profile updated") + fmt.Printf("Profile: %s\n", status.Profile) + if status.OAuthStatus == "missing" || status.OAuthStatus == "login_required" { + fmt.Println("Run 'notion-cli auth login' to authenticate this profile.") + } + if !status.HasAPIToken { + fmt.Println("Run 'notion-cli auth api setup' if this profile needs official API features.") + } return nil } type AuthLogoutCmd struct{} func (c *AuthLogoutCmd) Run(ctx *Context) error { - tokenStore, err := mcp.NewFileTokenStore() + tokenStore, err := mcp.NewFileTokenStore(ctx.Profile) if err != nil { output.PrintError(err) return err @@ -190,13 +288,13 @@ func (c *AuthAPISetupCmd) Run(ctx *Context) error { output.PrintWarning("Official API token does not match the expected Notion token format") _, _ = fmt.Fprintln(authAPIOutput, "Expected format: ntn_") } - if err := config.SetAPIToken(token); err != nil { + if err := config.SetAPITokenForProfile(ctx.Profile, token); err != nil { output.PrintError(err) return err } output.PrintSuccess("Official API token saved") - _, _ = fmt.Fprintf(authAPIOutput, "Config path: %s\n", mustConfigPath()) + _, _ = fmt.Fprintf(authAPIOutput, "Config path: %s\n", mustConfigPath(ctx.Profile)) return nil } @@ -244,6 +342,7 @@ func (c *AuthAPIVerifyCmd) Run(ctx *Context) error { enc.SetIndent("", " ") return enc.Encode(map[string]any{ "verified": true, + "profile": loaded.Profile, "token_source": loaded.APITokenSource, "config_path": loaded.ConfigPath, "base_url": loaded.Config.API.BaseURL, @@ -253,6 +352,7 @@ func (c *AuthAPIVerifyCmd) Run(ctx *Context) error { } output.PrintSuccess("Official API token verified") + _, _ = fmt.Fprintf(authAPIOutput, "Profile: %s\n", loaded.Profile) _, _ = fmt.Fprintf(authAPIOutput, "Token source: %s\n", loaded.APITokenSource) _, _ = fmt.Fprintf(authAPIOutput, "Config path: %s\n", loaded.ConfigPath) _, _ = fmt.Fprintf(authAPIOutput, "Base URL: %s\n", loaded.Config.API.BaseURL) @@ -285,7 +385,7 @@ func (c *AuthAPIUnsetCmd) Run(ctx *Context) error { return nil } - if err := config.UnsetAPIToken(); err != nil { + if err := config.UnsetAPITokenForProfile(ctx.Profile); err != nil { output.PrintError(err) return err } @@ -306,6 +406,7 @@ func printAuthAPIStatus(ctx *Context, loaded *cli.OfficialAPIConfig) error { enc.SetIndent("", " ") return enc.Encode(map[string]any{ "configured": hasToken, + "profile": loaded.Profile, "token_source": loaded.APITokenSource, "config_path": loaded.ConfigPath, "base_url": loaded.Config.API.BaseURL, @@ -319,6 +420,7 @@ func printAuthAPIStatus(ctx *Context, loaded *cli.OfficialAPIConfig) error { output.PrintWarning("Official API token not configured") } _, _ = fmt.Fprintln(authAPIOutput) + _, _ = fmt.Fprintf(authAPIOutput, "Profile: %s\n", loaded.Profile) _, _ = fmt.Fprintf(authAPIOutput, "Token source: %s\n", loaded.APITokenSource) _, _ = fmt.Fprintf(authAPIOutput, "Config path: %s\n", loaded.ConfigPath) _, _ = fmt.Fprintf(authAPIOutput, "Base URL: %s\n", loaded.Config.API.BaseURL) @@ -374,10 +476,69 @@ func printOfficialAPITokenSetupHint(out io.Writer, shouldOpenBrowser bool) { _, _ = fmt.Fprintln(out) } -func mustConfigPath() string { - path, err := config.Path() +func mustConfigPath(profile string) string { + path, err := config.PathForProfile(profile) if err != nil { return "" } return path } + +func inspectProfileStatus(profile string) (authProfileStatus, error) { + resolvedProfile, err := config.ResolveProfile(profile) + if err != nil { + return authProfileStatus{}, err + } + active, err := config.ActiveProfile() + if err != nil { + return authProfileStatus{}, err + } + paths, err := config.PathsForProfile(resolvedProfile) + if err != nil { + return authProfileStatus{}, err + } + loaded, err := config.LoadWithMeta(config.APIOverrides{Profile: resolvedProfile}) + if err != nil { + return authProfileStatus{}, err + } + + status := authProfileStatus{ + Profile: resolvedProfile, + Active: resolvedProfile == active, + HasAPIToken: loaded.HasConfigToken, + TokenPath: paths.TokenPath, + ConfigPath: paths.ConfigPath, + OAuthStatus: "missing", + } + + tokenStore, err := mcp.NewFileTokenStore(resolvedProfile) + if err != nil { + return authProfileStatus{}, err + } + token, err := tokenStore.GetToken(context.Background()) + if err != nil { + if err == mcp.ErrNoToken { + return status, nil + } + return authProfileStatus{}, err + } + + status.HasOAuthToken = strings.TrimSpace(token.AccessToken) != "" + if status.HasOAuthToken { + expiresAt := token.ExpiresAt + status.OAuthExpiresAt = &expiresAt + switch { + case !token.IsExpired(): + status.OAuthStatus = "valid" + case strings.TrimSpace(token.RefreshToken) != "": + // Access token is past expiry but a refresh token is on file, + // so the next command will silently refresh. Surface this as + // valid; the underlying expiry is still in OAuthExpiresAt for + // callers that want it. + status.OAuthStatus = "valid" + default: + status.OAuthStatus = "login_required" + } + } + return status, nil +} diff --git a/cmd/auth_api_test.go b/cmd/auth_api_test.go index 9a4452e..37f59d6 100644 --- a/cmd/auth_api_test.go +++ b/cmd/auth_api_test.go @@ -138,12 +138,15 @@ func TestAuthAPIStatusJSONUsesLoadedConfig(t *testing.T) { }) cmd := &AuthAPIStatusCmd{JSON: true} - if err := cmd.Run(&Context{APIToken: "env-token"}); err != nil { + if err := cmd.Run(&Context{Profile: "work", APIToken: "env-token"}); err != nil { t.Fatalf("Run: %v", err) } if !strings.Contains(out.String(), `"configured": true`) { t.Fatalf("unexpected output: %s", out.String()) } + if !strings.Contains(out.String(), `"profile": "work"`) { + t.Fatalf("unexpected output: %s", out.String()) + } if !strings.Contains(out.String(), `"token_source": "env"`) { t.Fatalf("unexpected output: %s", out.String()) } diff --git a/cmd/auth_test.go b/cmd/auth_test.go new file mode 100644 index 0000000..7ef6b44 --- /dev/null +++ b/cmd/auth_test.go @@ -0,0 +1,199 @@ +package cmd + +import ( + "context" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/lox/notion-cli/internal/config" + "github.com/lox/notion-cli/internal/mcp" + "github.com/mark3labs/mcp-go/client/transport" +) + +func isolateAuthConfig(t *testing.T) { + t.Helper() + + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(home, ".config")) +} + +func TestAuthUsePersistsActiveProfile(t *testing.T) { + isolateAuthConfig(t) + + cmd := &AuthUseCmd{Profile: "work"} + stdout := captureStdout(t, func() { + if err := cmd.Run(&Context{}); err != nil { + t.Fatalf("Run: %v", err) + } + }) + + active, err := config.ActiveProfile() + if err != nil { + t.Fatalf("ActiveProfile: %v", err) + } + if active != "work" { + t.Fatalf("active profile = %q, want work", active) + } + if !strings.Contains(stdout, "Profile: work") { + t.Fatalf("unexpected output: %q", stdout) + } +} + +func TestAuthListJSONShowsProfilesAndActiveState(t *testing.T) { + isolateAuthConfig(t) + + if err := config.SetActiveProfile("work"); err != nil { + t.Fatalf("SetActiveProfile: %v", err) + } + if err := config.SetAPITokenForProfile("personal", "personal-token"); err != nil { + t.Fatalf("SetAPITokenForProfile: %v", err) + } + store, err := mcp.NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + if err := store.SaveToken(context.Background(), &transport.Token{ + AccessToken: "oauth-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } + + cmd := &AuthListCmd{JSON: true} + stdout := captureStdout(t, func() { + if err := cmd.Run(&Context{}); err != nil { + t.Fatalf("Run: %v", err) + } + }) + + if !strings.Contains(stdout, `"profile": "work"`) { + t.Fatalf("unexpected output: %s", stdout) + } + if !strings.Contains(stdout, `"active": true`) { + t.Fatalf("unexpected output: %s", stdout) + } + if !strings.Contains(stdout, `"oauth_status": "valid"`) { + t.Fatalf("unexpected output: %s", stdout) + } + if !strings.Contains(stdout, `"profile": "personal"`) { + t.Fatalf("unexpected output: %s", stdout) + } +} + +func TestAuthStatusJSONReportsMissingTokenForProfile(t *testing.T) { + isolateAuthConfig(t) + + cmd := &AuthStatusCmd{JSON: true} + stdout := captureStdout(t, func() { + if err := cmd.Run(&Context{Profile: "work"}); err != nil { + t.Fatalf("Run: %v", err) + } + }) + + if !strings.Contains(stdout, `"profile": "work"`) { + t.Fatalf("unexpected output: %s", stdout) + } + if !strings.Contains(stdout, `"oauth_status": "missing"`) { + t.Fatalf("unexpected output: %s", stdout) + } +} + +func TestAuthStatusJSONReportsValidWhenAccessExpiredButRefreshAvailable(t *testing.T) { + isolateAuthConfig(t) + + store, err := mcp.NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + if err := store.SaveToken(context.Background(), &transport.Token{ + AccessToken: "stale-access", + TokenType: "Bearer", + RefreshToken: "rotating-refresh", + ExpiresAt: time.Now().Add(-time.Hour), + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } + + cmd := &AuthStatusCmd{JSON: true} + stdout := captureStdout(t, func() { + if err := cmd.Run(&Context{Profile: "work"}); err != nil { + t.Fatalf("Run: %v", err) + } + }) + + if !strings.Contains(stdout, `"oauth_status": "valid"`) { + t.Fatalf("expected valid status when refresh token present, got: %s", stdout) + } + if !strings.Contains(stdout, `"authenticated": true`) { + t.Fatalf("expected authenticated true when refresh token present, got: %s", stdout) + } + if !strings.Contains(stdout, `"has_oauth_token": true`) { + t.Fatalf("expected has_oauth_token field, got: %s", stdout) + } + if !strings.Contains(stdout, `"oauth_expires_at":`) { + t.Fatalf("expected oauth_expires_at field, got: %s", stdout) + } +} + +func TestAuthStatusJSONOmitsExpiryWhenAccessTokenMissing(t *testing.T) { + isolateAuthConfig(t) + + store, err := mcp.NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + if err := store.SaveToken(context.Background(), &transport.Token{ + TokenType: "Bearer", + RefreshToken: "leftover-refresh", + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } + + cmd := &AuthStatusCmd{JSON: true} + stdout := captureStdout(t, func() { + if err := cmd.Run(&Context{Profile: "work"}); err != nil { + t.Fatalf("Run: %v", err) + } + }) + + if !strings.Contains(stdout, `"oauth_status": "missing"`) { + t.Fatalf("expected missing status when access token empty, got: %s", stdout) + } + if strings.Contains(stdout, `"oauth_expires_at"`) { + t.Fatalf("expected no oauth_expires_at when access token empty, got: %s", stdout) + } +} + +func TestAuthStatusJSONReportsLoginRequiredWithoutRefreshToken(t *testing.T) { + isolateAuthConfig(t) + + store, err := mcp.NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + if err := store.SaveToken(context.Background(), &transport.Token{ + AccessToken: "stale-access", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-time.Hour), + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } + + cmd := &AuthStatusCmd{JSON: true} + stdout := captureStdout(t, func() { + if err := cmd.Run(&Context{Profile: "work"}); err != nil { + t.Fatalf("Run: %v", err) + } + }) + + if !strings.Contains(stdout, `"oauth_status": "login_required"`) { + t.Fatalf("expected login_required without refresh token, got: %s", stdout) + } + if !strings.Contains(stdout, `"authenticated": false`) { + t.Fatalf("expected authenticated false without refresh token, got: %s", stdout) + } +} diff --git a/cmd/root.go b/cmd/root.go index f5c1db5..3e63ade 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,6 +3,7 @@ package cmd import "github.com/lox/notion-cli/internal/config" type Context struct { + Profile string JSON bool Token string APIToken string @@ -11,6 +12,7 @@ type Context struct { } type CLI struct { + Profile string `help:"Config profile name" env:"NOTION_CLI_PROFILE"` Token string `help:"Access token (skips OAuth)" env:"NOTION_ACCESS_TOKEN" hidden:""` APIToken string `env:"NOTION_API_TOKEN" hidden:""` APIBaseURL string `env:"NOTION_API_BASE_URL" hidden:""` @@ -39,6 +41,7 @@ func officialAPIOverrides(ctx *Context) config.APIOverrides { return config.APIOverrides{} } return config.APIOverrides{ + Profile: ctx.Profile, BaseURL: ctx.APIBaseURL, NotionVersion: ctx.APINotionVersion, Token: ctx.APIToken, diff --git a/internal/cli/context.go b/internal/cli/context.go index d69b867..6794d10 100644 --- a/internal/cli/context.go +++ b/internal/cli/context.go @@ -12,12 +12,17 @@ import ( ) var accessToken string +var profile string var authRefreshNoticeWriter io.Writer = os.Stderr func SetAccessToken(token string) { accessToken = token } +func SetProfile(value string) { + profile = value +} + func GetClient() (*mcp.Client, error) { ctx := context.Background() @@ -33,6 +38,7 @@ func GetClient() (*mcp.Client, error) { if accessToken != "" { opts = append(opts, mcp.WithAccessToken(accessToken)) } + opts = append(opts, mcp.WithProfile(profile)) client, err := mcp.NewClient(opts...) if err != nil { @@ -51,7 +57,7 @@ func GetClient() (*mcp.Client, error) { } func autoRefreshIfNeeded(ctx context.Context) error { - tokenStore, err := mcp.NewFileTokenStore() + tokenStore, err := mcp.NewFileTokenStore(profile) if err != nil { return err } @@ -67,7 +73,7 @@ func autoRefreshIfNeeded(ctx context.Context) error { return fmt.Errorf("token expired and no refresh token available") } - _, err := mcp.RefreshToken(ctx, tokenStore) + _, err := mcp.RefreshTokenIfNeeded(ctx, tokenStore) if err != nil { return fmt.Errorf("auto-refresh failed: %w", err) } diff --git a/internal/cli/official_api.go b/internal/cli/official_api.go index f159fef..452fc8b 100644 --- a/internal/cli/official_api.go +++ b/internal/cli/official_api.go @@ -9,6 +9,7 @@ import ( type OfficialAPIConfig struct { Config config.Config + Profile string ConfigPath string APITokenSource string HasConfigToken bool @@ -21,6 +22,7 @@ func LoadOfficialAPIConfig(overrides config.APIOverrides) (*OfficialAPIConfig, e } return &OfficialAPIConfig{ Config: loaded.Config, + Profile: loaded.Profile, ConfigPath: loaded.Path, APITokenSource: loaded.APITokenSource, HasConfigToken: loaded.HasConfigToken, diff --git a/internal/config/config.go b/internal/config/config.go index 1c7c6b1..3173109 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,12 +6,17 @@ import ( "fmt" "os" "path/filepath" + "slices" "strings" ) const ( configDirName = "notion-cli" configFileName = "config.json" + tokenFileName = "token.json" + stateFileName = "state.json" + profilesDirName = "profiles" + defaultProfileName = "default" defaultAPIBaseURL = "https://api.notion.com/v1" defaultNotionAPIVer = "2026-03-11" ) @@ -28,17 +33,29 @@ type APIConfig struct { type LoadedConfig struct { Config Config + Profile string Path string APITokenSource string HasConfigToken bool } type APIOverrides struct { + Profile string BaseURL string NotionVersion string Token string } +type ProfilePaths struct { + Profile string + ConfigPath string + TokenPath string +} + +type State struct { + ActiveProfile string `json:"active_profile,omitempty"` +} + const ( APITokenSourceNone = "none" APITokenSourceConfig = "config" @@ -55,11 +72,329 @@ func Default() Config { } func Path() (string, error) { + return PathForProfile("") +} + +func ConfigDir() (string, error) { configDir, err := os.UserConfigDir() if err != nil { return "", err } - return filepath.Join(configDir, configDirName, configFileName), nil + return filepath.Join(configDir, configDirName), nil +} + +func ProfilesDir() (string, error) { + baseDir, err := ConfigDir() + if err != nil { + return "", err + } + return filepath.Join(baseDir, profilesDirName), nil +} + +func StatePath() (string, error) { + baseDir, err := ConfigDir() + if err != nil { + return "", err + } + return filepath.Join(baseDir, stateFileName), nil +} + +func PathForProfile(profile string) (string, error) { + resolvedProfile, err := ResolveProfile(profile) + if err != nil { + return "", err + } + profileDir, err := profileBaseDir(resolvedProfile) + if err != nil { + return "", err + } + return filepath.Join(profileDir, configFileName), nil +} + +func PathsForProfile(profile string) (ProfilePaths, error) { + resolvedProfile, err := ResolveProfile(profile) + if err != nil { + return ProfilePaths{}, err + } + + profileDir, err := profileBaseDir(resolvedProfile) + if err != nil { + return ProfilePaths{}, err + } + + tokenPath := filepath.Join(profileDir, tokenFileName) + if resolvedProfile == defaultProfileName { + tokenPath, err = legacyDefaultTokenPath() + if err != nil { + return ProfilePaths{}, err + } + } + + return ProfilePaths{ + Profile: resolvedProfile, + ConfigPath: filepath.Join(profileDir, configFileName), + TokenPath: tokenPath, + }, nil +} + +// profileBaseDir returns the directory that holds a profile's config file. +// It only depends on ConfigDir (XDG_CONFIG_HOME or its platform fallback) +// and never resolves HOME, so config-only flows keep working in environments +// where os.UserHomeDir is unavailable (e.g. minimal CI containers). +func profileBaseDir(resolvedProfile string) (string, error) { + baseDir, err := ConfigDir() + if err != nil { + return "", err + } + if resolvedProfile == defaultProfileName { + return baseDir, nil + } + return filepath.Join(baseDir, profilesDirName, resolvedProfile), nil +} + +// legacyDefaultTokenPath preserves the pre-profile OAuth token location. +// API config files use ConfigDir, but upstream OAuth tokens historically +// lived under ~/.config/notion-cli even on platforms where os.UserConfigDir +// resolves somewhere else. +func legacyDefaultTokenPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, ".config", configDirName, tokenFileName), nil +} + +func DefaultProfile() string { + return defaultProfileName +} + +func ResolveProfile(profile string) (string, error) { + normalized := strings.TrimSpace(profile) + if normalized == "" { + return defaultProfileName, nil + } + + runes := []rune(normalized) + if !isProfileEndpoint(runes[0]) || !isProfileEndpoint(runes[len(runes)-1]) { + return "", fmt.Errorf("invalid profile %q: start and end with a lowercase letter or number", profile) + } + for _, r := range runes { + if isProfileChar(r) { + continue + } + return "", fmt.Errorf("invalid profile %q: use lowercase letters, numbers, at sign, dot, underscore, and hyphen", profile) + } + if isWindowsReservedName(normalized) { + return "", fmt.Errorf("invalid profile %q: name is reserved on Windows", profile) + } + return normalized, nil +} + +// isWindowsReservedName reports whether the (already lowercased) profile +// matches a Windows reserved device name. Windows refuses to create files or +// directories with these basenames, so allowing them here would leave +// profiles that work on macOS/Linux but break on Windows the moment a +// command tries to read or write the profile's token/config files. +func isWindowsReservedName(name string) bool { + base := name + if dot := strings.IndexByte(base, '.'); dot >= 0 { + base = base[:dot] + } + switch base { + case "con", "prn", "aux", "nul": + return true + } + if len(base) == 4 && (base[:3] == "com" || base[:3] == "lpt") { + c := base[3] + if c >= '1' && c <= '9' { + return true + } + } + return false +} + +func isProfileEndpoint(r rune) bool { + return isLowercaseASCII(r) || isDigitASCII(r) +} + +func isProfileChar(r rune) bool { + return isLowercaseASCII(r) || isDigitASCII(r) || r == '.' || r == '_' || r == '-' || r == '@' +} + +func isLowercaseASCII(r rune) bool { + return r >= 'a' && r <= 'z' +} + +func isDigitASCII(r rune) bool { + return r >= '0' && r <= '9' +} + +func LoadState() (State, error) { + path, err := StatePath() + if err != nil { + return State{}, err + } + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return State{}, nil + } + return State{}, fmt.Errorf("read state: %w", err) + } + if len(data) == 0 { + return State{}, nil + } + var state State + if err := json.Unmarshal(data, &state); err != nil { + return State{}, fmt.Errorf("parse state: %w", err) + } + if state.ActiveProfile != "" { + resolved, err := ResolveProfile(state.ActiveProfile) + if err != nil { + return State{}, err + } + state.ActiveProfile = resolved + } + return state, nil +} + +func SaveState(state State) error { + if state.ActiveProfile != "" { + resolved, err := ResolveProfile(state.ActiveProfile) + if err != nil { + return err + } + state.ActiveProfile = resolved + } + + path, err := StatePath() + if err != nil { + return err + } + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("create state dir: %w", err) + } + if err := os.Chmod(dir, 0o700); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("secure state dir: %w", err) + } + + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return fmt.Errorf("marshal state: %w", err) + } + data = append(data, '\n') + + tmp, err := os.CreateTemp(dir, stateFileName+".*.tmp") + if err != nil { + return fmt.Errorf("create temp state: %w", err) + } + + tmpPath := tmp.Name() + cleanup := func() { + _ = tmp.Close() + _ = os.Remove(tmpPath) + } + if err := tmp.Chmod(0o600); err != nil { + cleanup() + return fmt.Errorf("secure temp state: %w", err) + } + if _, err := tmp.Write(data); err != nil { + cleanup() + return fmt.Errorf("write temp state: %w", err) + } + if err := tmp.Close(); err != nil { + cleanup() + return fmt.Errorf("close temp state: %w", err) + } + if err := os.Rename(tmpPath, path); err != nil { + cleanup() + return fmt.Errorf("replace state: %w", err) + } + if err := os.Chmod(path, 0o600); err != nil { + return fmt.Errorf("secure state file: %w", err) + } + return nil +} + +func ResolveSelectedProfile(requested string) (string, error) { + if strings.TrimSpace(requested) != "" { + return ResolveProfile(requested) + } + state, err := LoadState() + if err != nil { + return "", err + } + if strings.TrimSpace(state.ActiveProfile) != "" { + return ResolveProfile(state.ActiveProfile) + } + return DefaultProfile(), nil +} + +func SetActiveProfile(profile string) error { + resolved, err := ResolveProfile(profile) + if err != nil { + return err + } + return SaveState(State{ActiveProfile: resolved}) +} + +func ActiveProfile() (string, error) { + state, err := LoadState() + if err != nil { + return "", err + } + if strings.TrimSpace(state.ActiveProfile) != "" { + return state.ActiveProfile, nil + } + return DefaultProfile(), nil +} + +func ListProfiles() ([]string, error) { + baseDir, err := ProfilesDir() + if err != nil { + return nil, err + } + + names := map[string]struct{}{ + DefaultProfile(): {}, + } + active, err := ActiveProfile() + if err != nil { + return nil, err + } + names[active] = struct{}{} + + entries, err := os.ReadDir(baseDir) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("read profiles dir: %w", err) + } + for _, entry := range entries { + if !entry.IsDir() { + continue + } + resolved, err := ResolveProfile(entry.Name()) + if err != nil { + return nil, err + } + names[resolved] = struct{}{} + } + + var rest []string + for name := range names { + if name == active || name == DefaultProfile() { + continue + } + rest = append(rest, name) + } + slices.Sort(rest) + + profiles := []string{active} + if active != DefaultProfile() { + profiles = append(profiles, DefaultProfile()) + } + profiles = append(profiles, rest...) + return profiles, nil } func Load() (Config, error) { @@ -72,10 +407,11 @@ func Load() (Config, error) { func LoadWithMeta(overrides APIOverrides) (LoadedConfig, error) { cfg := Default() - path, err := Path() + paths, err := PathsForProfile(overrides.Profile) if err != nil { return LoadedConfig{}, err } + path := paths.ConfigPath fileCfg, err := loadFile(path) if err != nil { @@ -95,6 +431,7 @@ func LoadWithMeta(overrides APIOverrides) (LoadedConfig, error) { normalize(&cfg) return LoadedConfig{ Config: cfg, + Profile: paths.Profile, Path: path, APITokenSource: source, HasConfigToken: strings.TrimSpace(fileCfg.API.Token) != "", @@ -102,7 +439,11 @@ func LoadWithMeta(overrides APIOverrides) (LoadedConfig, error) { } func Save(cfg Config) error { - path, err := Path() + return SaveForProfile("", cfg) +} + +func SaveForProfile(profile string, cfg Config) error { + path, err := PathForProfile(profile) if err != nil { return err } @@ -156,25 +497,33 @@ func Save(cfg Config) error { } func SetAPIToken(token string) error { - cfg, err := loadForMutation() + return SetAPITokenForProfile("", token) +} + +func SetAPITokenForProfile(profile, token string) error { + cfg, err := loadForMutation(profile) if err != nil { return err } cfg.API.Token = strings.TrimSpace(token) - return Save(cfg) + return SaveForProfile(profile, cfg) } func UnsetAPIToken() error { - cfg, err := loadForMutation() + return UnsetAPITokenForProfile("") +} + +func UnsetAPITokenForProfile(profile string) error { + cfg, err := loadForMutation(profile) if err != nil { return err } cfg.API.Token = "" - return Save(cfg) + return SaveForProfile(profile, cfg) } -func loadForMutation() (Config, error) { - path, err := Path() +func loadForMutation(profile string) (Config, error) { + path, err := PathForProfile(profile) if err != nil { return Config{}, err } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f566b4b..c71afa0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,11 +3,21 @@ package config import ( "os" "path/filepath" + "reflect" + "strings" "testing" ) +func isolateConfigDir(t *testing.T) { + t.Helper() + + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(home, ".config")) +} + func TestLoadWithMetaDefaults(t *testing.T) { - t.Setenv("HOME", t.TempDir()) + isolateConfigDir(t) loaded, err := LoadWithMeta(APIOverrides{}) if err != nil { @@ -22,10 +32,13 @@ func TestLoadWithMetaDefaults(t *testing.T) { if loaded.APITokenSource != APITokenSourceNone { t.Fatalf("APITokenSource = %q, want %q", loaded.APITokenSource, APITokenSourceNone) } + if loaded.Profile != DefaultProfile() { + t.Fatalf("Profile = %q, want %q", loaded.Profile, DefaultProfile()) + } } func TestLoadWithMetaReportsConfigTokenSource(t *testing.T) { - t.Setenv("HOME", t.TempDir()) + isolateConfigDir(t) if err := SetAPIToken("secret-token"); err != nil { t.Fatalf("SetAPIToken: %v", err) } @@ -43,7 +56,7 @@ func TestLoadWithMetaReportsConfigTokenSource(t *testing.T) { } func TestLoadWithMetaEnvOverrideWins(t *testing.T) { - t.Setenv("HOME", t.TempDir()) + isolateConfigDir(t) if err := SetAPIToken("config-token"); err != nil { t.Fatalf("SetAPIToken: %v", err) } @@ -70,7 +83,7 @@ func TestLoadWithMetaEnvOverrideWins(t *testing.T) { } func TestUnsetAPITokenClearsStoredToken(t *testing.T) { - t.Setenv("HOME", t.TempDir()) + isolateConfigDir(t) if err := SetAPIToken("secret-token"); err != nil { t.Fatalf("SetAPIToken: %v", err) } @@ -91,7 +104,7 @@ func TestUnsetAPITokenClearsStoredToken(t *testing.T) { } func TestSaveSecuresConfigFile(t *testing.T) { - t.Setenv("HOME", t.TempDir()) + isolateConfigDir(t) cfg := Default() cfg.API.Token = "secret-token" @@ -120,3 +133,192 @@ func TestSaveSecuresConfigFile(t *testing.T) { t.Fatalf("config dir perm = %o, want 700", perm) } } + +func TestPathsForProfileDefaultAndNamed(t *testing.T) { + isolateConfigDir(t) + + defaultPaths, err := PathsForProfile("") + if err != nil { + t.Fatalf("PathsForProfile default: %v", err) + } + if defaultPaths.Profile != DefaultProfile() { + t.Fatalf("default profile = %q, want %q", defaultPaths.Profile, DefaultProfile()) + } + if got := filepath.Base(defaultPaths.ConfigPath); got != configFileName { + t.Fatalf("default config filename = %q, want %q", got, configFileName) + } + if got := filepath.Base(defaultPaths.TokenPath); got != tokenFileName { + t.Fatalf("default token filename = %q, want %q", got, tokenFileName) + } + if !strings.Contains(defaultPaths.TokenPath, filepath.Join(".config", configDirName, tokenFileName)) { + t.Fatalf("default token path = %q, want legacy .config/notion-cli path", defaultPaths.TokenPath) + } + + workPaths, err := PathsForProfile("work") + if err != nil { + t.Fatalf("PathsForProfile work: %v", err) + } + if workPaths.Profile != "work" { + t.Fatalf("work profile = %q, want work", workPaths.Profile) + } + if !strings.Contains(workPaths.ConfigPath, filepath.Join(profilesDirName, "work")) { + t.Fatalf("work config path = %q, want profiles/work segment", workPaths.ConfigPath) + } + if !strings.Contains(workPaths.TokenPath, filepath.Join(profilesDirName, "work")) { + t.Fatalf("work token path = %q, want profiles/work segment", workPaths.TokenPath) + } +} + +func TestProfileSpecificAPITokensAreIsolated(t *testing.T) { + isolateConfigDir(t) + + if err := SetAPIToken("default-token"); err != nil { + t.Fatalf("SetAPIToken default: %v", err) + } + if err := SetAPITokenForProfile("work", "work-token"); err != nil { + t.Fatalf("SetAPITokenForProfile: %v", err) + } + + defaultLoaded, err := LoadWithMeta(APIOverrides{}) + if err != nil { + t.Fatalf("LoadWithMeta default: %v", err) + } + if defaultLoaded.Config.API.Token != "default-token" { + t.Fatalf("default token = %q, want default-token", defaultLoaded.Config.API.Token) + } + + workLoaded, err := LoadWithMeta(APIOverrides{Profile: "work"}) + if err != nil { + t.Fatalf("LoadWithMeta work: %v", err) + } + if workLoaded.Config.API.Token != "work-token" { + t.Fatalf("work token = %q, want work-token", workLoaded.Config.API.Token) + } + if workLoaded.Path == defaultLoaded.Path { + t.Fatalf("profile config path should differ from default path") + } +} + +func TestResolveProfileRejectsInvalidNames(t *testing.T) { + for _, value := range []string{ + "../oops", + "work/team", + "two words", + "Work", + ".work", + "work.", + "work-", + "ümlaut", + } { + if _, err := ResolveProfile(value); err == nil { + t.Fatalf("ResolveProfile(%q) should fail", value) + } + } +} + +func TestResolveProfileRejectsWindowsReservedNames(t *testing.T) { + for _, value := range []string{ + "con", + "prn", + "aux", + "nul", + "com1", + "com9", + "lpt1", + "lpt9", + "con.txt", + } { + if _, err := ResolveProfile(value); err == nil { + t.Fatalf("ResolveProfile(%q) should fail as Windows reserved", value) + } + } +} + +func TestResolveProfileAllowsPortableNames(t *testing.T) { + for _, value := range []string{"work", "work-1", "personal_2", "brian@brianle.xyz"} { + got, err := ResolveProfile(value) + if err != nil { + t.Fatalf("ResolveProfile(%q): %v", value, err) + } + if got != value { + t.Fatalf("profile = %q, want %q", got, value) + } + } +} + +func TestResolveSelectedProfileUsesActiveStateWhenUnset(t *testing.T) { + isolateConfigDir(t) + if err := SetActiveProfile("work"); err != nil { + t.Fatalf("SetActiveProfile: %v", err) + } + + profile, err := ResolveSelectedProfile("") + if err != nil { + t.Fatalf("ResolveSelectedProfile: %v", err) + } + if profile != "work" { + t.Fatalf("profile = %q, want work", profile) + } +} + +func TestResolveSelectedProfilePrefersExplicitValue(t *testing.T) { + isolateConfigDir(t) + if err := SetActiveProfile("work"); err != nil { + t.Fatalf("SetActiveProfile: %v", err) + } + + profile, err := ResolveSelectedProfile("personal") + if err != nil { + t.Fatalf("ResolveSelectedProfile: %v", err) + } + if profile != "personal" { + t.Fatalf("profile = %q, want personal", profile) + } +} + +func TestListProfilesIncludesActiveDefaultAndNamedProfiles(t *testing.T) { + isolateConfigDir(t) + if err := SetActiveProfile("work"); err != nil { + t.Fatalf("SetActiveProfile: %v", err) + } + if err := SetAPITokenForProfile("personal", "personal-token"); err != nil { + t.Fatalf("SetAPITokenForProfile: %v", err) + } + + got, err := ListProfiles() + if err != nil { + t.Fatalf("ListProfiles: %v", err) + } + want := []string{"work", "default", "personal"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("profiles = %#v, want %#v", got, want) + } +} + +func TestPathForProfileDefaultDoesNotRequireHome(t *testing.T) { + isolateConfigDir(t) + t.Setenv("HOME", "") + + // macOS and Windows resolve UserConfigDir from HOME, so this regression + // only surfaces on platforms (Linux containers, etc.) where ConfigDir + // can be satisfied via XDG_CONFIG_HOME without HOME. + if _, err := os.UserConfigDir(); err != nil { + t.Skipf("UserConfigDir requires HOME on this platform: %v", err) + } + + path, err := PathForProfile("") + if err != nil { + t.Fatalf("PathForProfile(default): %v", err) + } + if path == "" { + t.Fatalf("expected non-empty config path for default profile") + } + + path, err = PathForProfile("work") + if err != nil { + t.Fatalf("PathForProfile(work): %v", err) + } + if path == "" { + t.Fatalf("expected non-empty config path for named profile") + } +} diff --git a/internal/mcp/client.go b/internal/mcp/client.go index a536b48..038d380 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -31,6 +31,7 @@ type ClientOption func(*clientConfig) type clientConfig struct { endpoint string accessToken string + profile string } func WithEndpoint(endpoint string) ClientOption { @@ -45,6 +46,12 @@ func WithAccessToken(token string) ClientOption { } } +func WithProfile(profile string) ClientOption { + return func(c *clientConfig) { + c.profile = profile + } +} + func NewClient(opts ...ClientOption) (*Client, error) { cfg := &clientConfig{ endpoint: DefaultEndpoint, @@ -53,7 +60,7 @@ func NewClient(opts ...ClientOption) (*Client, error) { opt(cfg) } - tokenStore, err := NewFileTokenStore() + tokenStore, err := NewFileTokenStore(cfg.profile) if err != nil { return nil, fmt.Errorf("create token store: %w", err) } @@ -190,9 +197,8 @@ func (c *Client) Search(ctx context.Context, query string, opts *SearchOptions) } func buildSearchToolArgs(query string, opts *SearchOptions) map[string]any { - args := map[string]any{} - if strings.TrimSpace(query) != "" { - args["query"] = query + args := map[string]any{ + "query": strings.TrimSpace(query), } if opts != nil && opts.ContentSearchMode != "" { args["content_search_mode"] = opts.ContentSearchMode diff --git a/internal/mcp/client_test.go b/internal/mcp/client_test.go index da811d1..5723a33 100644 --- a/internal/mcp/client_test.go +++ b/internal/mcp/client_test.go @@ -5,9 +5,10 @@ import ( "testing" ) -func TestBuildSearchToolArgsOmitsBlankQuery(t *testing.T) { +func TestBuildSearchToolArgsIncludesBlankQuery(t *testing.T) { got := buildSearchToolArgs("", &SearchOptions{ContentSearchMode: "workspace_search"}) want := map[string]any{ + "query": "", "content_search_mode": "workspace_search", } diff --git a/internal/mcp/oauth.go b/internal/mcp/oauth.go index d47a6ee..5465178 100644 --- a/internal/mcp/oauth.go +++ b/internal/mcp/oauth.go @@ -11,6 +11,7 @@ import ( "net/http" "os/exec" "runtime" + "strings" "time" "github.com/mark3labs/mcp-go/client" @@ -20,6 +21,12 @@ import ( const callbackPath = "/callback" +const refreshSkew = 5 * time.Minute + +type refreshTokenFunc func(context.Context, *FileTokenStore, *transport.Token) (*transport.Token, error) + +var refreshOAuthToken refreshTokenFunc = refreshTokenWithNotion + func GenerateCodeVerifier() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { @@ -201,11 +208,83 @@ func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore) error { } func RefreshToken(ctx context.Context, tokenStore *FileTokenStore) (*transport.Token, error) { - token, err := tokenStore.GetToken(ctx) + return refreshTokenLocked(ctx, tokenStore, true) +} + +func RefreshTokenIfNeeded(ctx context.Context, tokenStore *FileTokenStore) (*transport.Token, error) { + return refreshTokenLocked(ctx, tokenStore, false) +} + +func refreshTokenLocked(ctx context.Context, tokenStore *FileTokenStore, force bool) (*transport.Token, error) { + var refreshed *transport.Token + err := tokenStore.WithLock(ctx, func() error { + token, err := tokenStore.GetToken(ctx) + if err != nil { + return fmt.Errorf("get token: %w", err) + } + + if !force && tokenFresh(token) { + refreshed = token + return nil + } + + newToken, err := refreshOAuthToken(ctx, tokenStore, token) + if err != nil { + if isInvalidGrantError(err) { + latest, latestErr := tokenStore.GetToken(ctx) + if latestErr == nil && tokenFresh(latest) && !sameToken(latest, token) { + refreshed = latest + return nil + } + return fmt.Errorf("refresh token was rejected; browser login required: %w", err) + } + return err + } + + if err := tokenStore.SaveToken(ctx, newToken); err != nil { + return fmt.Errorf("save token: %w", err) + } + refreshed = newToken + return nil + }) if err != nil { - return nil, fmt.Errorf("get token: %w", err) + return nil, err + } + if refreshed == nil { + return nil, errors.New("refresh did not return a token") + } + return refreshed, nil +} + +func tokenFresh(token *transport.Token) bool { + return token != nil && + token.AccessToken != "" && + (token.ExpiresAt.IsZero() || token.ExpiresAt.After(time.Now().Add(refreshSkew))) +} + +func sameToken(a, b *transport.Token) bool { + if a == nil || b == nil { + return a == b + } + return a.AccessToken == b.AccessToken && + a.RefreshToken == b.RefreshToken && + a.ExpiresAt.Equal(b.ExpiresAt) +} + +func isInvalidGrantError(err error) bool { + if err == nil { + return false + } + + var oauthErr transport.OAuthError + if errors.As(err, &oauthErr) && oauthErr.ErrorCode == "invalid_grant" { + return true } + return strings.Contains(err.Error(), "invalid_grant") +} + +func refreshTokenWithNotion(ctx context.Context, tokenStore *FileTokenStore, token *transport.Token) (*transport.Token, error) { if token.RefreshToken == "" { return nil, errors.New("no refresh token available") } @@ -249,10 +328,6 @@ func RefreshToken(ctx context.Context, tokenStore *FileTokenStore) (*transport.T return nil, fmt.Errorf("refresh token: %w", err) } - if err := tokenStore.SaveToken(ctx, newToken); err != nil { - return nil, fmt.Errorf("save token: %w", err) - } - return newToken, nil } diff --git a/internal/mcp/oauth_test.go b/internal/mcp/oauth_test.go new file mode 100644 index 0000000..c76baa0 --- /dev/null +++ b/internal/mcp/oauth_test.go @@ -0,0 +1,239 @@ +package mcp + +import ( + "context" + "fmt" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/mark3labs/mcp-go/client/transport" +) + +func isolateMCPConfig(t *testing.T) { + t.Helper() + + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(home, ".config")) +} + +func TestRefreshTokenSkipsRefreshWhenAnotherCallerAlreadyRefreshed(t *testing.T) { + isolateMCPConfig(t) + + store, err := NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + saveExpiredToken(t, store) + + oldRefresh := refreshOAuthToken + t.Cleanup(func() { + refreshOAuthToken = oldRefresh + }) + + var refreshCalls atomic.Int32 + refreshOAuthToken = func(ctx context.Context, tokenStore *FileTokenStore, token *transport.Token) (*transport.Token, error) { + refreshCalls.Add(1) + time.Sleep(50 * time.Millisecond) + return &transport.Token{ + AccessToken: "new-access", + TokenType: "bearer", + RefreshToken: "new-refresh", + ExpiresAt: time.Now().Add(time.Hour), + }, nil + } + + var wg sync.WaitGroup + errs := make(chan error, 2) + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + token, err := RefreshTokenIfNeeded(context.Background(), store) + if err != nil { + errs <- err + return + } + if token.AccessToken != "new-access" { + errs <- fmt.Errorf("access token = %q, want new-access", token.AccessToken) + } + }() + } + wg.Wait() + close(errs) + + for err := range errs { + if err != nil { + t.Fatal(err) + } + } + if got := refreshCalls.Load(); got != 1 { + t.Fatalf("refresh calls = %d, want 1", got) + } +} + +func TestRefreshTokenInvalidGrantUsesNewerSavedToken(t *testing.T) { + isolateMCPConfig(t) + + store, err := NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + saveExpiredToken(t, store) + + oldRefresh := refreshOAuthToken + t.Cleanup(func() { + refreshOAuthToken = oldRefresh + }) + + refreshOAuthToken = func(ctx context.Context, tokenStore *FileTokenStore, token *transport.Token) (*transport.Token, error) { + if err := tokenStore.SaveToken(ctx, &transport.Token{ + AccessToken: "winner-access", + TokenType: "bearer", + RefreshToken: "winner-refresh", + ExpiresAt: time.Now().Add(time.Hour), + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } + return nil, fmt.Errorf("refresh token: %w", transport.OAuthError{ErrorCode: "invalid_grant"}) + } + + token, err := RefreshToken(context.Background(), store) + if err != nil { + t.Fatalf("RefreshToken: %v", err) + } + if token.AccessToken != "winner-access" { + t.Fatalf("access token = %q, want winner-access", token.AccessToken) + } +} + +func TestRefreshTokenForcesRefreshWhenTokenIsFresh(t *testing.T) { + isolateMCPConfig(t) + + store, err := NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + if err := store.SaveClientID(context.Background(), "client-123"); err != nil { + t.Fatalf("SaveClientID: %v", err) + } + if err := store.SaveToken(context.Background(), &transport.Token{ + AccessToken: "fresh-access", + TokenType: "bearer", + RefreshToken: "fresh-refresh", + ExpiresAt: time.Now().Add(time.Hour), + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } + + oldRefresh := refreshOAuthToken + t.Cleanup(func() { + refreshOAuthToken = oldRefresh + }) + + var refreshCalls atomic.Int32 + refreshOAuthToken = func(context.Context, *FileTokenStore, *transport.Token) (*transport.Token, error) { + refreshCalls.Add(1) + return &transport.Token{ + AccessToken: "forced-access", + TokenType: "bearer", + RefreshToken: "forced-refresh", + ExpiresAt: time.Now().Add(time.Hour), + }, nil + } + + token, err := RefreshToken(context.Background(), store) + if err != nil { + t.Fatalf("RefreshToken: %v", err) + } + if token.AccessToken != "forced-access" { + t.Fatalf("access token = %q, want forced-access", token.AccessToken) + } + if got := refreshCalls.Load(); got != 1 { + t.Fatalf("refresh calls = %d, want 1", got) + } +} + +func TestRefreshTokenInvalidGrantRequiresLoginWhenNoNewerTokenExists(t *testing.T) { + isolateMCPConfig(t) + + store, err := NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + saveExpiredToken(t, store) + + oldRefresh := refreshOAuthToken + t.Cleanup(func() { + refreshOAuthToken = oldRefresh + }) + + refreshOAuthToken = func(context.Context, *FileTokenStore, *transport.Token) (*transport.Token, error) { + return nil, fmt.Errorf("refresh token: %w", transport.OAuthError{ErrorCode: "invalid_grant"}) + } + + _, err = RefreshToken(context.Background(), store) + if err == nil { + t.Fatal("RefreshToken returned nil error, want browser login required") + } + if got := err.Error(); !strings.Contains(got, "browser login required") { + t.Fatalf("error = %q, want browser login required", got) + } +} + +func TestRefreshTokenInvalidGrantDoesNotAcceptSameFreshToken(t *testing.T) { + isolateMCPConfig(t) + + store, err := NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + if err := store.SaveClientID(context.Background(), "client-123"); err != nil { + t.Fatalf("SaveClientID: %v", err) + } + if err := store.SaveToken(context.Background(), &transport.Token{ + AccessToken: "fresh-access", + TokenType: "bearer", + RefreshToken: "fresh-refresh", + ExpiresAt: time.Now().Add(time.Hour), + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } + + oldRefresh := refreshOAuthToken + t.Cleanup(func() { + refreshOAuthToken = oldRefresh + }) + + refreshOAuthToken = func(context.Context, *FileTokenStore, *transport.Token) (*transport.Token, error) { + return nil, fmt.Errorf("refresh token: %w", transport.OAuthError{ErrorCode: "invalid_grant"}) + } + + _, err = RefreshToken(context.Background(), store) + if err == nil { + t.Fatal("RefreshToken returned nil error, want browser login required") + } + if got := err.Error(); !strings.Contains(got, "browser login required") { + t.Fatalf("error = %q, want browser login required", got) + } +} + +func saveExpiredToken(t *testing.T, store *FileTokenStore) { + t.Helper() + + if err := store.SaveClientID(context.Background(), "client-123"); err != nil { + t.Fatalf("SaveClientID: %v", err) + } + if err := store.SaveToken(context.Background(), &transport.Token{ + AccessToken: "old-access", + TokenType: "bearer", + RefreshToken: "old-refresh", + ExpiresAt: time.Now().Add(-time.Hour), + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } +} diff --git a/internal/mcp/token_lock_unix.go b/internal/mcp/token_lock_unix.go new file mode 100644 index 0000000..b3055ab --- /dev/null +++ b/internal/mcp/token_lock_unix.go @@ -0,0 +1,30 @@ +//go:build !windows + +package mcp + +import ( + "os" + + "golang.org/x/sys/unix" +) + +func acquireFileLock(path string) (*os.File, error) { + lockFile, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + if err := unix.Flock(int(lockFile.Fd()), unix.LOCK_EX); err != nil { + _ = lockFile.Close() + return nil, err + } + return lockFile, nil +} + +func releaseFileLock(lockFile *os.File) error { + err := unix.Flock(int(lockFile.Fd()), unix.LOCK_UN) + closeErr := lockFile.Close() + if err != nil { + return err + } + return closeErr +} diff --git a/internal/mcp/token_lock_windows.go b/internal/mcp/token_lock_windows.go new file mode 100644 index 0000000..2a62f16 --- /dev/null +++ b/internal/mcp/token_lock_windows.go @@ -0,0 +1,41 @@ +//go:build windows + +package mcp + +import ( + "os" + + "golang.org/x/sys/windows" +) + +func acquireFileLock(path string) (*os.File, error) { + lockFile, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + + var overlapped windows.Overlapped + if err := windows.LockFileEx( + windows.Handle(lockFile.Fd()), + windows.LOCKFILE_EXCLUSIVE_LOCK, + 0, + 1, + 0, + &overlapped, + ); err != nil { + _ = lockFile.Close() + return nil, err + } + + return lockFile, nil +} + +func releaseFileLock(lockFile *os.File) error { + var overlapped windows.Overlapped + err := windows.UnlockFileEx(windows.Handle(lockFile.Fd()), 0, 1, 0, &overlapped) + closeErr := lockFile.Close() + if err != nil { + return err + } + return closeErr +} diff --git a/internal/mcp/token_store.go b/internal/mcp/token_store.go index 9b86cac..d2d6154 100644 --- a/internal/mcp/token_store.go +++ b/internal/mcp/token_store.go @@ -9,29 +9,28 @@ import ( "sync" "time" + "github.com/lox/notion-cli/internal/config" "github.com/mark3labs/mcp-go/client/transport" ) -const ( - configDir = ".config/notion-cli" - configFile = "token.json" -) - var ErrNoToken = errors.New("no token available") +var ( + refreshLocksMu sync.Mutex + refreshLocks = map[string]*sync.Mutex{} +) + type FileTokenStore struct { path string mu sync.RWMutex } -func NewFileTokenStore() (*FileTokenStore, error) { - homeDir, err := os.UserHomeDir() +func NewFileTokenStore(profile string) (*FileTokenStore, error) { + paths, err := config.PathsForProfile(profile) if err != nil { return nil, err } - - path := filepath.Join(homeDir, configDir, configFile) - return &FileTokenStore{path: path}, nil + return &FileTokenStore{path: paths.TokenPath}, nil } func (s *FileTokenStore) GetToken(ctx context.Context) (*transport.Token, error) { @@ -91,12 +90,7 @@ func (s *FileTokenStore) SaveToken(ctx context.Context, token *transport.Token) ClientID: existing.ClientID, } - data, err := json.MarshalIndent(stored, "", " ") - if err != nil { - return err - } - - return os.WriteFile(s.path, data, 0600) + return s.writeStoredToken(ctx, stored) } func (s *FileTokenStore) Clear() error { @@ -113,6 +107,10 @@ func (s *FileTokenStore) Path() string { return s.path } +func (s *FileTokenStore) LockPath() string { + return s.path + ".lock" +} + type storedToken struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` @@ -167,10 +165,89 @@ func (s *FileTokenStore) SaveClientID(ctx context.Context, clientID string) erro stored.ClientID = clientID - data, err = json.MarshalIndent(stored, "", " ") + return s.writeStoredToken(ctx, stored) +} + +func (s *FileTokenStore) writeStoredToken(ctx context.Context, stored storedToken) error { + if err := ctx.Err(); err != nil { + return err + } + + dir := filepath.Dir(s.path) + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + + data, err := json.MarshalIndent(stored, "", " ") + if err != nil { + return err + } + + tmp, err := os.CreateTemp(dir, ".token-*.json") + if err != nil { + return err + } + tmpPath := tmp.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if err := tmp.Chmod(0600); err != nil { + _ = tmp.Close() + return err + } + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + + if err := os.Rename(tmpPath, s.path); err != nil { + return err + } + cleanup = false + return nil +} + +func (s *FileTokenStore) WithLock(ctx context.Context, fn func() error) error { + if err := ctx.Err(); err != nil { + return err + } + + processLock := processRefreshLock(s.LockPath()) + processLock.Lock() + defer processLock.Unlock() + + dir := filepath.Dir(s.path) + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + + lockFile, err := acquireFileLock(s.LockPath()) if err != nil { return err } + defer func() { _ = releaseFileLock(lockFile) }() - return os.WriteFile(s.path, data, 0600) + if err := ctx.Err(); err != nil { + return err + } + return fn() +} + +func processRefreshLock(path string) *sync.Mutex { + refreshLocksMu.Lock() + defer refreshLocksMu.Unlock() + + lock, ok := refreshLocks[path] + if !ok { + lock = &sync.Mutex{} + refreshLocks[path] = lock + } + return lock } diff --git a/internal/mcp/token_store_test.go b/internal/mcp/token_store_test.go new file mode 100644 index 0000000..9004e19 --- /dev/null +++ b/internal/mcp/token_store_test.go @@ -0,0 +1,73 @@ +package mcp + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/client/transport" +) + +func TestNewFileTokenStoreUsesProfilePath(t *testing.T) { + isolateMCPConfig(t) + + store, err := NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + + if got := filepath.Base(store.Path()); got != "token.json" { + t.Fatalf("token filename = %q, want token.json", got) + } + if !strings.Contains(store.Path(), filepath.Join("profiles", "work")) { + t.Fatalf("token path = %q, want profiles/work segment", store.Path()) + } +} + +func TestSaveTokenWritesAtomicallyAndPreservesClientID(t *testing.T) { + isolateMCPConfig(t) + + store, err := NewFileTokenStore("work") + if err != nil { + t.Fatalf("NewFileTokenStore: %v", err) + } + + if err := store.SaveClientID(context.Background(), "client-123"); err != nil { + t.Fatalf("SaveClientID: %v", err) + } + if err := store.SaveToken(context.Background(), &transport.Token{ + AccessToken: "access-123", + TokenType: "bearer", + RefreshToken: "refresh-123", + ExpiresAt: time.Now().Add(time.Hour), + }); err != nil { + t.Fatalf("SaveToken: %v", err) + } + + info, err := os.Stat(store.Path()) + if err != nil { + t.Fatalf("stat token file: %v", err) + } + if got := info.Mode().Perm(); got != 0600 { + t.Fatalf("token file mode = %o, want 0600", got) + } + + data, err := os.ReadFile(store.Path()) + if err != nil { + t.Fatalf("read token file: %v", err) + } + var stored storedToken + if err := json.Unmarshal(data, &stored); err != nil { + t.Fatalf("unmarshal token: %v", err) + } + if stored.ClientID != "client-123" { + t.Fatalf("client ID = %q, want client-123", stored.ClientID) + } + if stored.RefreshToken != "refresh-123" { + t.Fatalf("refresh token = %q, want refresh-123", stored.RefreshToken) + } +} diff --git a/main.go b/main.go index a7a7232..9ef9ab7 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "github.com/alecthomas/kong" "github.com/lox/notion-cli/cmd" "github.com/lox/notion-cli/internal/cli" + "github.com/lox/notion-cli/internal/config" ) var version = "dev" @@ -23,8 +24,12 @@ func main() { kong.UsageOnError(), kong.Vars{"version": version}, ) + profile, err := config.ResolveSelectedProfile(c.Profile) + ctx.FatalIfErrorf(err) cli.SetAccessToken(c.Token) - err := ctx.Run(&cmd.Context{ + cli.SetProfile(profile) + err = ctx.Run(&cmd.Context{ + Profile: profile, Token: c.Token, APIToken: c.APIToken, APIBaseURL: c.APIBaseURL, diff --git a/skills/notion/SKILL.md b/skills/notion/SKILL.md index 6502630..6fbb5f1 100644 --- a/skills/notion/SKILL.md +++ b/skills/notion/SKILL.md @@ -30,8 +30,8 @@ The CLI uses OAuth authentication for MCP-backed commands. On first use, it open ```bash notion-cli auth login # Authenticate with Notion -notion-cli auth status # Check authentication status -notion-cli auth refresh # Refresh token if status shows expired token +notion-cli auth status # Show active profile and OAuth state (diagnostic) +notion-cli auth refresh # Force-refresh; commands auto-refresh on use, so rarely needed notion-cli auth logout # Clear credentials ``` @@ -48,6 +48,19 @@ notion-cli auth api unset For CI/headless environments, set `NOTION_API_TOKEN`. +### Multiple accounts + +Every command accepts `--profile ` (or `NOTION_CLI_PROFILE`) to target a specific Notion account. Named profiles keep credentials isolated under `~/.config/notion-cli/profiles//`; the implicit default profile uses the existing top-level paths. + +```bash +notion-cli auth login --profile work +notion-cli page list --profile work +export NOTION_CLI_PROFILE=work # pin for the shell session +notion-cli auth use work # make work the default profile +``` + +Resolution precedence: `--profile` > `NOTION_CLI_PROFILE` > `notion-cli auth use ` > implicit default profile. + ## Available Commands ``` @@ -191,6 +204,6 @@ notion-cli search "api" --json | jq '.[] | .title' 6. **Inline comments by default** - `page view` includes open page comments and inline block discussions unless `--no-comments` is set 7. **Raw output** - Use `--raw` with `page view` to see the original Notion markup 8. **JSON for parsing** - Use `--json` when you need to extract specific fields, including the `Comments` array from `page view` -9. **Auth preflight** - Run `notion-cli auth status --json` before a multi-step workflow and refresh/login if needed +9. **No auth preflight** - Just run the command; the CLI auto-refreshes tokens on use. `notion-cli auth status` and `notion-cli auth list` are diagnostic surfaces, not health gates - do not poll them as a sanity check before each call. Only run `notion-cli auth login` if a real command returns an authentication error. 10. **API fallback preflight** - Run `notion-cli auth api verify` before workflows that need local image upload 11. **Error handling** - If a targeted `page edit` call fails, rerun with `--replace` as a safe fallback