diff --git a/.gitignore b/.gitignore index 1dfa72ce..0f686a85 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,6 @@ overleaf.kubeconfig # coverage report coverage.out coverage.html + +# claude code +CLAUDE.md \ No newline at end of file diff --git a/helm-chart/values.yaml b/helm-chart/values.yaml index e8e2c200..86546e8e 100644 --- a/helm-chart/values.yaml +++ b/helm-chart/values.yaml @@ -14,7 +14,7 @@ paperdebuggerMcpServer: image: ghcr.io/paperdebugger/paperdebugger-mcp-server:main-14409c5 paperdebuggerXtraMcpServer: - image: ghcr.io/paperdebugger/xtragpt-mcp-server:sha-95f70ce + image: ghcr.io/paperdebugger/xtragpt-mcp-server:sha-510bc96 mongo: in_cluster: true diff --git a/internal/api/chat/get_citation_keys.go b/internal/api/chat/get_citation_keys.go new file mode 100644 index 00000000..63fb985c --- /dev/null +++ b/internal/api/chat/get_citation_keys.go @@ -0,0 +1,51 @@ +package chat + +import ( + "context" + + "paperdebugger/internal/libs/contextutil" + "paperdebugger/internal/libs/shared" + "paperdebugger/internal/models" + chatv2 "paperdebugger/pkg/gen/api/chat/v2" +) + +func (s *ChatServerV2) GetCitationKeys( + ctx context.Context, + req *chatv2.GetCitationKeysRequest, +) (*chatv2.GetCitationKeysResponse, error) { + if req.GetSentence() == "" { + return nil, shared.ErrBadRequest("sentence is required") + } + if req.GetProjectId() == "" { + return nil, shared.ErrBadRequest("project_id is required") + } + + actor, err := contextutil.GetActor(ctx) + if err != nil { + return nil, err + } + + settings, err := s.userService.GetUserSettings(ctx, actor.ID) + if err != nil { + return nil, err + } + + llmProvider := &models.LLMProviderConfig{ + APIKey: settings.OpenAIAPIKey, + } + + citationKeys, err := s.aiClientV2.GetCitationKeys( + ctx, + req.GetSentence(), + actor.ID, + req.GetProjectId(), + llmProvider, + ) + if err != nil { + return nil, err + } + + return &chatv2.GetCitationKeysResponse{ + CitationKeys: citationKeys, + }, nil +} diff --git a/internal/api/mapper/user.go b/internal/api/mapper/user.go index a7fa8538..78c98ef3 100644 --- a/internal/api/mapper/user.go +++ b/internal/api/mapper/user.go @@ -9,7 +9,7 @@ func MapProtoSettingsToModel(settings *userv1.Settings) *models.Settings { return &models.Settings{ ShowShortcutsAfterSelection: settings.ShowShortcutsAfterSelection, FullWidthPaperDebuggerButton: settings.FullWidthPaperDebuggerButton, - EnableCompletion: settings.EnableCompletion, + EnableCitationSuggestion: settings.EnableCitationSuggestion, FullDocumentRag: settings.FullDocumentRag, ShowedOnboarding: settings.ShowedOnboarding, OpenAIAPIKey: settings.OpenaiApiKey, @@ -20,7 +20,7 @@ func MapModelSettingsToProto(settings *models.Settings) *userv1.Settings { return &userv1.Settings{ ShowShortcutsAfterSelection: settings.ShowShortcutsAfterSelection, FullWidthPaperDebuggerButton: settings.FullWidthPaperDebuggerButton, - EnableCompletion: settings.EnableCompletion, + EnableCitationSuggestion: settings.EnableCitationSuggestion, FullDocumentRag: settings.FullDocumentRag, ShowedOnboarding: settings.ShowedOnboarding, OpenaiApiKey: settings.OpenAIAPIKey, diff --git a/internal/models/user.go b/internal/models/user.go index c9bd1509..22e03ad2 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -5,7 +5,7 @@ import "go.mongodb.org/mongo-driver/v2/bson" type Settings struct { ShowShortcutsAfterSelection bool `bson:"show_shortcuts_after_selection"` FullWidthPaperDebuggerButton bool `bson:"full_width_paper_debugger_button"` - EnableCompletion bool `bson:"enable_completion"` + EnableCitationSuggestion bool `bson:"enable_citation_suggestion"` FullDocumentRag bool `bson:"full_document_rag"` ShowedOnboarding bool `bson:"showed_onboarding"` OpenAIAPIKey string `bson:"openai_api_key"` diff --git a/internal/services/toolkit/client/get_citation_keys.go b/internal/services/toolkit/client/get_citation_keys.go new file mode 100644 index 00000000..1995d590 --- /dev/null +++ b/internal/services/toolkit/client/get_citation_keys.go @@ -0,0 +1,272 @@ +package client + +// TODO: This file should not place in the client package. +import ( + "context" + "fmt" + "paperdebugger/internal/models" + "paperdebugger/internal/services/toolkit/tools/xtramcp" + "regexp" + "strings" + + "github.com/openai/openai-go/v3" + "go.mongodb.org/mongo-driver/v2/bson" +) + +var ( + // Regex patterns compiled once + titleFieldRe = regexp.MustCompile(`(?i)title\s*=\s*`) // matches "title = " prefix + entryStartRe = regexp.MustCompile(`(?i)^\s*@(\w+)\s*\{`) // eg. @article{ + stringEntryRe = regexp.MustCompile(`(?i)^\s*@String\s*\{`) // eg. @String{ + multiSpaceRe = regexp.MustCompile(` {2,}`) + + // Fields to exclude from bibliography (not useful for citation matching) + excludedFields = []string{ + "address", "institution", "pages", "eprint", "primaryclass", "volume", "number", + "edition", "numpages", "articleno", "publisher", "editor", "doi", "url", "acmid", + "issn", "archivePrefix", "year", "month", "day", "eid", "lastaccessed", "organization", + "school", "isbn", "mrclass", "mrnumber", "mrreviewer", "type", "order_no", "location", + "howpublished", "distincturl", "issue_date", "archived", "series", "source", + } + excludeFieldRe = regexp.MustCompile(`(?i)^\s*(` + strings.Join(excludedFields, "|") + `)\s*=`) +) + +// braceBalance returns the net brace count (opens - closes) in a string. +func braceBalance(s string) int { + return strings.Count(s, "{") - strings.Count(s, "}") +} + +// isQuoteUnclosed returns true if the string has an odd number of double quotes. +func isQuoteUnclosed(s string) bool { + return strings.Count(s, `"`)%2 == 1 +} + +// extractBalancedValue extracts a BibTeX field value (braced or quoted) starting at pos. +// It is needed for (1) getting full title (for abstract lookup) and (2) skipping excluded +// fields that may span multiple lines. +// Returns the extracted content and end position, or empty string and -1 if not found. +func extractBalancedValue(s string, pos int) (string, int) { + // Skip whitespace + for pos < len(s) && (s[pos] == ' ' || s[pos] == '\t' || s[pos] == '\n' || s[pos] == '\r') { + pos++ + } + if pos >= len(s) { + return "", -1 + } + + switch s[pos] { + case '{': + depth := 0 + start := pos + 1 + for i := pos; i < len(s); i++ { + switch s[i] { + case '{': + depth++ + case '}': + depth-- + if depth == 0 { + return s[start:i], i + 1 + } + } + } + case '"': + start := pos + 1 + for i := start; i < len(s); i++ { + if s[i] == '"' { + return s[start:i], i + 1 + } + } + } + return "", -1 +} + +// extractTitle extracts the title from a BibTeX entry string. +// It handles nested braces like title = {A Study of {COVID-19}}. +func extractTitle(entry string) string { + loc := titleFieldRe.FindStringIndex(entry) + if loc == nil { + return "" + } + content, _ := extractBalancedValue(entry, loc[1]) + return strings.TrimSpace(content) +} + +// parseBibFile extracts bibliography entries from a .bib file's lines, +// filtering out @String macros, comments, and excluded fields (url, doi, etc.). +func parseBibFile(lines []string) []string { + var entries []string + var currentEntry []string + + // It handles multi-line field values by tracking brace/quote balance: + // - skipBraces > 0: currently skipping a {bracketed} value, wait until balanced + // - skipQuotes = true: currently skipping a "quoted" value, wait for closing quote + + var entryDepth int // brace depth for current entry (0 = entry complete) + var skipBraces int // > 0 means we're skipping lines until braces balance + var skipQuotes bool // true means we're skipping lines until closing quote + + for _, line := range lines { + // Skip empty lines and comments + if trimmed := strings.TrimSpace(line); trimmed == "" || strings.HasPrefix(trimmed, "%") { + continue + } + + // If skipping a multi-line {bracketed} field value, keep skipping until balanced + if skipBraces > 0 { + skipBraces += braceBalance(line) + continue + } + + // If skipping a multi-line "quoted" field value, keep skipping until closing quote + if skipQuotes { + if isQuoteUnclosed(line) { // odd quote count = found closing quote + skipQuotes = false + } + continue + } + + // Skip @String{...} macro definitions + if stringEntryRe.MatchString(line) { + skipBraces = braceBalance(line) + continue + } + + // Skip excluded fields (url, doi, pages, etc.) - may span multiple lines + if excludeFieldRe.MatchString(line) { + if strings.Contains(line, "={") || strings.Contains(line, "= {") { + skipBraces = braceBalance(line) + } else if strings.Contains(line, `="`) || strings.Contains(line, `= "`) { + skipQuotes = isQuoteUnclosed(line) + } + continue + } + + // Start of new entry: @article{key, or @book{key, etc. + if entryStartRe.MatchString(line) { + if len(currentEntry) > 0 { + entries = append(entries, strings.Join(currentEntry, "\n")) + } + currentEntry = []string{line} + entryDepth = braceBalance(line) + continue + } + + // Continue building current entry + if len(currentEntry) > 0 { + currentEntry = append(currentEntry, line) + entryDepth += braceBalance(line) + if entryDepth <= 0 { // entry complete when braces balance + entries = append(entries, strings.Join(currentEntry, "\n")) + currentEntry = nil + } + } + } + + // Last entry if file doesn't end with balanced braces + if len(currentEntry) > 0 { + entries = append(entries, strings.Join(currentEntry, "\n")) + } + return entries +} + +// fetchAbstracts enriches entries with abstracts from XtraMCP using batch API. +func (a *AIClientV2) fetchAbstracts(ctx context.Context, entries []string) []string { + // Extract titles + var titles []string + for _, entry := range entries { + if title := extractTitle(entry); title != "" { + titles = append(titles, title) + } + } + + // Fetch abstracts and build lookup map + abstracts := make(map[string]string) + svc := xtramcp.NewXtraMCPServices(a.cfg.XtraMCPURI) + resp, err := svc.GetPaperAbstracts(ctx, titles) + if err == nil && resp.Success { + for _, r := range resp.Results { + if r.Found { + abstracts[r.Title] = r.Abstract + } + } + } + + // Enrich entries + result := make([]string, len(entries)) + for i, entry := range entries { + if abstract, ok := abstracts[extractTitle(entry)]; ok && abstract != "" { + if pos := strings.LastIndex(entry, "}"); pos > 0 { + result[i] = entry[:pos] + fmt.Sprintf(",\n abstract = {%s}\n}", abstract) + continue + } + } + result[i] = entry + } + return result +} + +// GetBibliographyForCitation extracts bibliography content from a project's .bib files. +// It excludes non-essential fields to save tokens and fetches abstracts from XtraMCP. +func (a *AIClientV2) GetBibliographyForCitation(ctx context.Context, userId bson.ObjectID, projectId string) (string, error) { + project, err := a.projectService.GetProject(ctx, userId, projectId) + if err != nil { + return "", err + } + + // Parse all .bib files + var entries []string + for _, doc := range project.Docs { + if strings.HasSuffix(doc.Filepath, ".bib") { + entries = append(entries, parseBibFile(doc.Lines)...) + } + } + + // Enrich with abstracts + entries = a.fetchAbstracts(ctx, entries) + + // Join and normalize + bibliography := strings.Join(entries, "\n") + return multiSpaceRe.ReplaceAllString(bibliography, " "), nil +} + +func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userId bson.ObjectID, projectId string, llmProvider *models.LLMProviderConfig) ([]string, error) { + bibliography, err := a.GetBibliographyForCitation(ctx, userId, projectId) + + if err != nil { + return nil, err + } + + emptyCitation := "none" + + // Bibliography is placed at the start of the prompt to leverage prompt caching + message := fmt.Sprintf("Bibliography: %s\nSentence: %s\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '%s'.", bibliography, sentence, emptyCitation) + + _, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{ + openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."), + openai.UserMessage(message), + }, llmProvider) + + if err != nil { + return nil, err + } + + if len(resp) == 0 { + return []string{}, nil + } + + citationKeysStr := strings.TrimSpace(resp[0].Payload.GetAssistant().GetContent()) + + if citationKeysStr == "" || citationKeysStr == emptyCitation { + return []string{}, nil + } + + // Parse comma-separated keys + var result []string + for _, key := range strings.Split(citationKeysStr, ",") { + if trimmed := strings.TrimSpace(key); trimmed != "" { + result = append(result, trimmed) + } + } + + return result, nil +} diff --git a/internal/services/toolkit/client/get_citation_keys_test.go b/internal/services/toolkit/client/get_citation_keys_test.go new file mode 100644 index 00000000..4d2a857d --- /dev/null +++ b/internal/services/toolkit/client/get_citation_keys_test.go @@ -0,0 +1,763 @@ +package client_test + +import ( + "context" + "os" + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + "paperdebugger/internal/services" + "paperdebugger/internal/services/toolkit/client" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/v2/bson" +) + +// setupTestClient creates an AIClientV2 for testing with MongoDB +func setupTestClient(t *testing.T) (*client.AIClientV2, *services.ProjectService) { + os.Setenv("PD_MONGO_URI", "mongodb://localhost:27017") + dbInstance, err := db.NewDB(cfg.GetCfg(), logger.GetLogger()) + if err != nil { + t.Skipf("MongoDB not available: %v", err) + } + + projectService := services.NewProjectService(dbInstance, cfg.GetCfg(), logger.GetLogger()) + aiClient := client.NewAIClientV2( + dbInstance, + &services.ReverseCommentService{}, + projectService, + cfg.GetCfg(), + logger.GetLogger(), + ) + return aiClient, projectService +} + +// createTestProject creates a project with the given bib content for testing +func createTestProject(t *testing.T, projectService *services.ProjectService, userId bson.ObjectID, projectId string, bibContent []string) { + ctx := context.Background() + project := &models.Project{ + Docs: []models.ProjectDoc{ + { + ID: "bib-doc", + Version: 1, + Filepath: "references.bib", + Lines: bibContent, + }, + }, + } + _, err := projectService.UpsertProject(ctx, userId, projectId, project) + assert.NoError(t, err) +} + +func TestGetBibliographyForCitation_FieldExclusion(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-field-exclusion-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{smith2020,", + " author = {John Smith},", + " title = {A Great Paper},", + " journal = {Nature},", + " url = {https://example.com/paper},", + " doi = {10.1234/example},", + " pages = {1-10},", + " volume = {5},", + " publisher = {Nature Publishing},", + " year = {2020},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Essential fields should be kept + assert.Contains(t, result, "author") + assert.Contains(t, result, "John Smith") + assert.Contains(t, result, "title") + assert.Contains(t, result, "A Great Paper") + assert.Contains(t, result, "journal") + assert.Contains(t, result, "Nature") + + // Non-essential fields should be excluded + assert.NotContains(t, result, "url") + assert.NotContains(t, result, "https://example.com") + assert.NotContains(t, result, "doi") + assert.NotContains(t, result, "10.1234") + assert.NotContains(t, result, "pages") + assert.NotContains(t, result, "1-10") + assert.NotContains(t, result, "volume") + assert.NotContains(t, result, "publisher") + assert.NotContains(t, result, "year") +} + +func TestGetBibliographyForCitation_MultiLineFieldExclusion(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-multiline-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{multiline2023,", + " author = {Test Author},", + " url = {https://example.com/", + " very/long/path/to/paper},", + " title = {Test Paper},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should keep author and title + assert.Contains(t, result, "author") + assert.Contains(t, result, "title") + + // Should exclude multi-line url field completely + assert.NotContains(t, result, "url") + assert.NotContains(t, result, "very/long/path") +} + +func TestGetBibliographyForCitation_StringEntryExclusion(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-string-entry-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@String{nature = {Nature Publishing}}", + "@String{longjournal = {Journal of Very", + " Long Names and Things}}", + "@article{test2023,", + " author = {Test Author},", + " title = {Test Title},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should keep the article entry + assert.Contains(t, result, "@article") + assert.Contains(t, result, "author") + assert.Contains(t, result, "title") + + // Should exclude @String entries + assert.NotContains(t, result, "@String") + assert.NotContains(t, result, "Nature Publishing") + assert.NotContains(t, result, "Long Names") +} + +func TestGetBibliographyForCitation_CommentsAndEmptyLines(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-comments-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "% This is a comment that should be excluded", + "@article{commented2023,", + "", + " author = {Test Author},", + " % Another comment", + " ", + " title = {Test Title},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should keep the article content + assert.Contains(t, result, "author") + assert.Contains(t, result, "title") + + // Should exclude comments + assert.NotContains(t, result, "This is a comment") + assert.NotContains(t, result, "Another comment") +} + +func TestGetBibliographyForCitation_CaseInsensitiveFieldMatching(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-case-insensitive-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{casetest,", + " AUTHOR = {Case Author},", + " URL = {https://example.com},", + " Title = {Case Title},", + " DOI = {10.1234/test},", + " Pages = {1-10},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should keep essential fields regardless of case + assert.Contains(t, result, "AUTHOR") + assert.Contains(t, result, "Title") + + // Should exclude non-essential fields regardless of case + assert.NotContains(t, result, "URL") + assert.NotContains(t, result, "DOI") + assert.NotContains(t, result, "Pages") +} + +func TestGetBibliographyForCitation_OnlyBibFiles(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-only-bib-" + bson.NewObjectID().Hex() + + project := &models.Project{ + Docs: []models.ProjectDoc{ + { + ID: "tex-doc", + Version: 1, + Filepath: "main.tex", + Lines: []string{"\\documentclass{article}", "\\begin{document}", "Hello"}, + }, + { + ID: "bib-doc", + Version: 1, + Filepath: "refs.bib", + Lines: []string{"@article{test,", " author = {Bib Author},", "}"}, + }, + { + ID: "txt-doc", + Version: 1, + Filepath: "notes.txt", + Lines: []string{"Some notes here"}, + }, + }, + } + _, err := projectService.UpsertProject(ctx, userId, projectId, project) + assert.NoError(t, err) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should only contain bib file content + assert.Contains(t, result, "Bib Author") + + // Should not contain tex or txt content + assert.NotContains(t, result, "documentclass") + assert.NotContains(t, result, "Some notes") +} + +func TestGetBibliographyForCitation_QuotedFieldValues(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-quoted-" + bson.NewObjectID().Hex() + + bibContent := []string{ + `@article{quoted2023,`, + ` author = "Alice Author",`, + ` url = "https://example.com",`, + ` title = "Quoted Title",`, + `}`, + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should keep author and title + assert.Contains(t, result, "author") + assert.Contains(t, result, "title") + + // Should exclude url even with quoted value + assert.NotContains(t, result, "url") + assert.NotContains(t, result, "https://example.com") +} + +func TestGetBibliographyForCitation_NoBibFiles(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-no-bib-" + bson.NewObjectID().Hex() + + project := &models.Project{ + Docs: []models.ProjectDoc{ + { + ID: "tex-doc", + Version: 1, + Filepath: "main.tex", + Lines: []string{"\\documentclass{article}", "\\begin{document}", "Hello", "\\end{document}"}, + }, + }, + } + _, err := projectService.UpsertProject(ctx, userId, projectId, project) + assert.NoError(t, err) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + assert.Empty(t, result) +} + +func TestGetBibliographyForCitation_EmptyBibFile(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-empty-bib-" + bson.NewObjectID().Hex() + + createTestProject(t, projectService, userId, projectId, []string{}) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + assert.Empty(t, result) +} + +func TestGetBibliographyForCitation_NestedBraces(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-nested-braces-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{nested2023,", + " author = {John {van} Smith},", + " title = {A {GPU}-Based Approach to {NLP}},", + " journal = {Journal of {AI} Research},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should preserve nested braces in kept fields + assert.Contains(t, result, "author") + assert.Contains(t, result, "{van}") + assert.Contains(t, result, "title") + assert.Contains(t, result, "{GPU}") + assert.Contains(t, result, "{NLP}") + assert.Contains(t, result, "journal") + assert.Contains(t, result, "{AI}") +} + +func TestGetBibliographyForCitation_DifferentEntryTypes(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-entry-types-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{article2023,", + " author = {Article Author},", + " title = {Article Title},", + "}", + "@book{book2023,", + " author = {Book Author},", + " title = {Book Title},", + "}", + "@inproceedings{inproc2023,", + " author = {Conference Author},", + " title = {Conference Paper},", + " booktitle = {ICML 2023},", + "}", + "@misc{misc2023,", + " author = {Misc Author},", + " title = {Misc Title},", + " note = {Some note},", + "}", + "@phdthesis{thesis2023,", + " author = {PhD Author},", + " title = {Thesis Title},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should include all entry types + assert.Contains(t, result, "@article") + assert.Contains(t, result, "@book") + assert.Contains(t, result, "@inproceedings") + assert.Contains(t, result, "@misc") + assert.Contains(t, result, "@phdthesis") + + // Should preserve booktitle (not in excluded list) + assert.Contains(t, result, "booktitle") + assert.Contains(t, result, "ICML 2023") + + // Should preserve note (not in excluded list) + assert.Contains(t, result, "note") +} + +func TestGetBibliographyForCitation_MultiLineQuotedValues(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-multiline-quoted-" + bson.NewObjectID().Hex() + + bibContent := []string{ + `@article{quoted2023,`, + ` author = "Test Author",`, + ` url = "https://example.com/very/`, + ` long/path/to/paper",`, + ` title = "Test Title",`, + `}`, + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should keep author and title + assert.Contains(t, result, "author") + assert.Contains(t, result, "title") + + // Should exclude multi-line quoted url field + assert.NotContains(t, result, "url") + assert.NotContains(t, result, "long/path") +} + +func TestGetBibliographyForCitation_MalformedEntry(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-malformed-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{valid2023,", + " author = {Valid Author},", + " title = {Valid Title},", + "}", + "@article{malformed2023,", + " author = {Malformed Author},", + " title = {Missing closing brace", + "@article{aftermalformed,", + " author = {After Author},", + " title = {After Title},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Should at least parse the valid entry + assert.Contains(t, result, "Valid Author") + assert.Contains(t, result, "Valid Title") +} + +func TestGetBibliographyForCitation_TitleMultilineBraces(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-title-multiline-braces-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{multiline2023,", + " author = {Test Author},", + " title = {A Very Long Title That Spans", + " Multiple Lines in the Bib File},", + " journal = {Test Journal},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Title should be preserved even when spanning multiple lines + assert.Contains(t, result, "title") + assert.Contains(t, result, "A Very Long Title") + assert.Contains(t, result, "Multiple Lines") +} + +func TestGetBibliographyForCitation_TitleMultilineQuotes(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-title-multiline-quotes-" + bson.NewObjectID().Hex() + + bibContent := []string{ + `@article{quotedtitle2023,`, + ` author = "Test Author",`, + ` title = "A Quoted Title That Spans`, + ` Multiple Lines",`, + ` journal = "Test Journal",`, + `}`, + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Title should be preserved even with multiline quotes + assert.Contains(t, result, "title") + assert.Contains(t, result, "A Quoted Title") + assert.Contains(t, result, "Multiple Lines") +} + +func TestGetBibliographyForCitation_ExcludedFieldNestedBraces(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-excluded-nested-braces-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{nested2023,", + " author = {Test Author},", + " title = {Test Title},", + " url = {https://example.com/{version}/path/{id}},", + " doi = {10.1234/{special}/value},", + " journal = {Test Journal},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Essential fields should be kept + assert.Contains(t, result, "author") + assert.Contains(t, result, "title") + assert.Contains(t, result, "journal") + + // Excluded fields with nested braces should be completely removed + assert.NotContains(t, result, "url") + assert.NotContains(t, result, "{version}") + assert.NotContains(t, result, "doi") + assert.NotContains(t, result, "{special}") +} + +func TestGetBibliographyForCitation_TitleMultilineNestedBraces(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-title-multiline-nested-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{multilinested2023,", + " author = {Test Author},", + " title = {A Study of {COVID-19} and Its", + " Impact on {Machine Learning}", + " Applications},", + " journal = {Test Journal},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Title with multiline nested braces should be preserved + assert.Contains(t, result, "title") + assert.Contains(t, result, "{COVID-19}") + assert.Contains(t, result, "{Machine Learning}") + assert.Contains(t, result, "Applications") +} + +func TestGetBibliographyForCitation_ExcludedFieldMultilineNestedBraces(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-excluded-multiline-nested-" + bson.NewObjectID().Hex() + + bibContent := []string{ + "@article{exclmultnest2023,", + " author = {Test Author},", + " title = {Test Title},", + " url = {https://example.com/{api}/{v2}", + " /resources/{id}/data},", + " journal = {Test Journal},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // Essential fields should be kept + assert.Contains(t, result, "author") + assert.Contains(t, result, "title") + assert.Contains(t, result, "journal") + + // Excluded field with multiline nested braces should be completely removed + assert.NotContains(t, result, "url") + assert.NotContains(t, result, "{api}") + assert.NotContains(t, result, "{v2}") + assert.NotContains(t, result, "/resources/") +} + +func TestGetBibliographyForCitation_EssentialFieldsPreserved(t *testing.T) { + aiClient, projectService := setupTestClient(t) + ctx := context.Background() + userId := bson.NewObjectID() + projectId := "test-essential-fields-" + bson.NewObjectID().Hex() + + // Test that important fields for citation matching are preserved + bibContent := []string{ + "@article{essential2023,", + " author = {Essential Author},", + " title = {Essential Title},", + " journal = {Essential Journal},", + " booktitle = {Essential Booktitle},", + " note = {Essential Note},", + " keywords = {machine learning, AI},", + " abstract = {This is the abstract.},", + "}", + } + + createTestProject(t, projectService, userId, projectId, bibContent) + + result, err := aiClient.GetBibliographyForCitation(ctx, userId, projectId) + assert.NoError(t, err) + + // These fields should be preserved as they're useful for citation matching + assert.Contains(t, result, "author") + assert.Contains(t, result, "title") + assert.Contains(t, result, "journal") + assert.Contains(t, result, "booktitle") + assert.Contains(t, result, "note") + assert.Contains(t, result, "keywords") + assert.Contains(t, result, "abstract") +} + +// TestCitationKeysParsing tests the expected parsing behavior for citation key responses. +// This verifies the parsing logic that GetCitationKeys uses internally. +func TestCitationKeysParsing(t *testing.T) { + // Helper that mimics the parsing logic in GetCitationKeys + parseCitationKeys := func(response string) []string { + emptyCitation := "none" + citationKeysStr := strings.TrimSpace(response) + + if citationKeysStr == "" || citationKeysStr == emptyCitation { + return []string{} + } + + keys := strings.Split(citationKeysStr, ",") + result := make([]string, 0, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result + } + + tests := []struct { + name string + response string + expected []string + }{ + { + name: "single key", + response: "smith2020", + expected: []string{"smith2020"}, + }, + { + name: "multiple keys comma separated", + response: "smith2020,jones2021,doe2022", + expected: []string{"smith2020", "jones2021", "doe2022"}, + }, + { + name: "keys with spaces around commas", + response: "smith2020, jones2021, doe2022", + expected: []string{"smith2020", "jones2021", "doe2022"}, + }, + { + name: "empty response", + response: "", + expected: []string{}, + }, + { + name: "none response", + response: "none", + expected: []string{}, + }, + { + name: "whitespace only", + response: " ", + expected: []string{}, + }, + { + name: "response with leading/trailing whitespace", + response: " smith2020,jones2021 ", + expected: []string{"smith2020", "jones2021"}, + }, + { + name: "handles empty segments from trailing comma", + response: "smith2020,jones2021,", + expected: []string{"smith2020", "jones2021"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseCitationKeys(tt.response) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestCitationPromptFormat verifies the expected prompt structure. +// This ensures the prompt format meets requirements (bibliography first for caching, etc.) +func TestCitationPromptFormat(t *testing.T) { + // Helper that mimics the prompt building in GetCitationKeys + buildPrompt := func(bibliography, sentence string) string { + emptyCitation := "none" + return "Bibliography: " + bibliography + "\nSentence: " + sentence + "\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '" + emptyCitation + "'." + } + + t.Run("bibliography comes first for prompt caching", func(t *testing.T) { + prompt := buildPrompt("@article{test}", "Test sentence") + assert.True(t, strings.HasPrefix(prompt, "Bibliography:"), + "prompt should start with Bibliography for prompt caching") + }) + + t.Run("contains bibliography content", func(t *testing.T) { + prompt := buildPrompt("@article{smith2020, author={Smith}}", "Test sentence") + assert.Contains(t, prompt, "@article{smith2020") + assert.Contains(t, prompt, "author={Smith}") + }) + + t.Run("contains sentence", func(t *testing.T) { + prompt := buildPrompt("@article{test}", "Machine learning is transforming research.") + assert.Contains(t, prompt, "Machine learning is transforming research.") + }) + + t.Run("includes empty citation marker", func(t *testing.T) { + prompt := buildPrompt("", "Test") + assert.Contains(t, prompt, "none") + }) + + t.Run("includes format instructions", func(t *testing.T) { + prompt := buildPrompt("", "Test") + assert.Contains(t, prompt, "comma") + assert.Contains(t, prompt, "key1,key2") + }) +} diff --git a/internal/services/toolkit/tools/xtramcp/services.go b/internal/services/toolkit/tools/xtramcp/services.go new file mode 100644 index 00000000..880d3693 --- /dev/null +++ b/internal/services/toolkit/tools/xtramcp/services.go @@ -0,0 +1,81 @@ +package xtramcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" +) + +// PaperAbstractResponse represents the response from XtraMCP paper-abstract REST API +type PaperAbstractResponse struct { + Success bool `json:"success"` + Found bool `json:"found"` + Title string `json:"title"` + Abstract string `json:"abstract"` +} + +// PaperAbstractsRequest represents the request body for batch paper abstracts API +type PaperAbstractsRequest struct { + Titles []string `json:"titles"` +} + +// PaperAbstractsResponse represents the response from batch paper abstracts API +type PaperAbstractsResponse struct { + Success bool `json:"success"` + Results []PaperAbstractResponse `json:"results"` +} + +// XtraMCPServices provides access to XtraMCP REST APIs that don't require MCP session +type XtraMCPServices struct { + baseURL string + client *http.Client +} + +// NewXtraMCPServices creates a new XtraMCP services client +func NewXtraMCPServices(baseURL string) *XtraMCPServices { + return &XtraMCPServices{ + baseURL: baseURL, + client: &http.Client{}, + } +} + +// GetPaperAbstracts fetches abstracts for multiple papers in a single request +func (s *XtraMCPServices) GetPaperAbstracts(ctx context.Context, titles []string) (*PaperAbstractsResponse, error) { + if len(titles) == 0 { + return &PaperAbstractsResponse{Success: true, Results: []PaperAbstractResponse{}}, nil + } + + baseURL := strings.TrimSuffix(s.baseURL, "/mcp") + endpoint := fmt.Sprintf("%s/api/paper-abstracts", baseURL) + + reqBody, err := json.Marshal(PaperAbstractsRequest{Titles: titles}) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var result PaperAbstractsResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return &result, nil +} diff --git a/internal/services/user.go b/internal/services/user.go index 6734fd30..d8c520be 100644 --- a/internal/services/user.go +++ b/internal/services/user.go @@ -122,7 +122,7 @@ func (s *UserService) GetDefaultSettings() models.Settings { return models.Settings{ ShowShortcutsAfterSelection: true, FullWidthPaperDebuggerButton: true, - EnableCompletion: false, + EnableCitationSuggestion: false, FullDocumentRag: false, ShowedOnboarding: false, } diff --git a/pkg/gen/api/auth/v1/auth_grpc.pb.go b/pkg/gen/api/auth/v1/auth_grpc.pb.go index 3b72abb0..19f029a7 100644 --- a/pkg/gen/api/auth/v1/auth_grpc.pb.go +++ b/pkg/gen/api/auth/v1/auth_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: auth/v1/auth.proto diff --git a/pkg/gen/api/chat/v1/chat_grpc.pb.go b/pkg/gen/api/chat/v1/chat_grpc.pb.go index c0916102..59daab03 100644 --- a/pkg/gen/api/chat/v1/chat_grpc.pb.go +++ b/pkg/gen/api/chat/v1/chat_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: chat/v1/chat.proto diff --git a/pkg/gen/api/chat/v2/chat.pb.go b/pkg/gen/api/chat/v2/chat.pb.go index 3ba45df6..0d312c55 100644 --- a/pkg/gen/api/chat/v2/chat.pb.go +++ b/pkg/gen/api/chat/v2/chat.pb.go @@ -7,12 +7,13 @@ package chatv2 import ( - _ "google.golang.org/genproto/googleapis/api/annotations" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" unsafe "unsafe" + + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -1897,6 +1898,104 @@ func (*CreateConversationMessageStreamResponse_StreamError) isCreateConversation func (*CreateConversationMessageStreamResponse_ReasoningChunk) isCreateConversationMessageStreamResponse_ResponsePayload() { } +// Request to get citation keys suggestion based on project bibliography +type GetCitationKeysRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sentence string `protobuf:"bytes,1,opt,name=sentence,proto3" json:"sentence,omitempty"` + ProjectId string `protobuf:"bytes,2,opt,name=project_id,json=projectId,proto3" json:"project_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetCitationKeysRequest) Reset() { + *x = GetCitationKeysRequest{} + mi := &file_chat_v2_chat_proto_msgTypes[30] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetCitationKeysRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetCitationKeysRequest) ProtoMessage() {} + +func (x *GetCitationKeysRequest) ProtoReflect() protoreflect.Message { + mi := &file_chat_v2_chat_proto_msgTypes[30] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetCitationKeysRequest.ProtoReflect.Descriptor instead. +func (*GetCitationKeysRequest) Descriptor() ([]byte, []int) { + return file_chat_v2_chat_proto_rawDescGZIP(), []int{30} +} + +func (x *GetCitationKeysRequest) GetSentence() string { + if x != nil { + return x.Sentence + } + return "" +} + +func (x *GetCitationKeysRequest) GetProjectId() string { + if x != nil { + return x.ProjectId + } + return "" +} + +// Response containing the suggested citation keys +type GetCitationKeysResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + CitationKeys []string `protobuf:"bytes,1,rep,name=citation_keys,json=citationKeys,proto3" json:"citation_keys,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetCitationKeysResponse) Reset() { + *x = GetCitationKeysResponse{} + mi := &file_chat_v2_chat_proto_msgTypes[31] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetCitationKeysResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetCitationKeysResponse) ProtoMessage() {} + +func (x *GetCitationKeysResponse) ProtoReflect() protoreflect.Message { + mi := &file_chat_v2_chat_proto_msgTypes[31] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetCitationKeysResponse.ProtoReflect.Descriptor instead. +func (*GetCitationKeysResponse) Descriptor() ([]byte, []int) { + return file_chat_v2_chat_proto_rawDescGZIP(), []int{31} +} + +func (x *GetCitationKeysResponse) GetCitationKeys() []string { + if x != nil { + return x.CitationKeys + } + return nil +} + var File_chat_v2_chat_proto protoreflect.FileDescriptor const file_chat_v2_chat_proto_rawDesc = "" + @@ -2030,17 +2129,24 @@ const file_chat_v2_chat_proto_rawDesc = "" + "\x13stream_finalization\x18\x06 \x01(\v2\x1b.chat.v2.StreamFinalizationH\x00R\x12streamFinalization\x129\n" + "\fstream_error\x18\a \x01(\v2\x14.chat.v2.StreamErrorH\x00R\vstreamError\x12B\n" + "\x0freasoning_chunk\x18\b \x01(\v2\x17.chat.v2.ReasoningChunkH\x00R\x0ereasoningChunkB\x12\n" + - "\x10response_payload*R\n" + + "\x10response_payload\"S\n" + + "\x16GetCitationKeysRequest\x12\x1a\n" + + "\bsentence\x18\x01 \x01(\tR\bsentence\x12\x1d\n" + + "\n" + + "project_id\x18\x02 \x01(\tR\tprojectId\">\n" + + "\x17GetCitationKeysResponse\x12#\n" + + "\rcitation_keys\x18\x01 \x03(\tR\fcitationKeys*R\n" + "\x10ConversationType\x12!\n" + "\x1dCONVERSATION_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n" + - "\x17CONVERSATION_TYPE_DEBUG\x10\x012\xa8\a\n" + + "\x17CONVERSATION_TYPE_DEBUG\x10\x012\xa7\b\n" + "\vChatService\x12\x83\x01\n" + "\x11ListConversations\x12!.chat.v2.ListConversationsRequest\x1a\".chat.v2.ListConversationsResponse\"'\x82\xd3\xe4\x93\x02!\x12\x1f/_pd/api/v2/chats/conversations\x12\x8f\x01\n" + "\x0fGetConversation\x12\x1f.chat.v2.GetConversationRequest\x1a .chat.v2.GetConversationResponse\"9\x82\xd3\xe4\x93\x023\x121/_pd/api/v2/chats/conversations/{conversation_id}\x12\xc2\x01\n" + "\x1fCreateConversationMessageStream\x12/.chat.v2.CreateConversationMessageStreamRequest\x1a0.chat.v2.CreateConversationMessageStreamResponse\":\x82\xd3\xe4\x93\x024:\x01*\"//_pd/api/v2/chats/conversations/messages/stream0\x01\x12\x9b\x01\n" + "\x12UpdateConversation\x12\".chat.v2.UpdateConversationRequest\x1a#.chat.v2.UpdateConversationResponse\"<\x82\xd3\xe4\x93\x026:\x01*21/_pd/api/v2/chats/conversations/{conversation_id}\x12\x98\x01\n" + "\x12DeleteConversation\x12\".chat.v2.DeleteConversationRequest\x1a#.chat.v2.DeleteConversationResponse\"9\x82\xd3\xe4\x93\x023*1/_pd/api/v2/chats/conversations/{conversation_id}\x12\x82\x01\n" + - "\x13ListSupportedModels\x12#.chat.v2.ListSupportedModelsRequest\x1a$.chat.v2.ListSupportedModelsResponse\" \x82\xd3\xe4\x93\x02\x1a\x12\x18/_pd/api/v2/chats/modelsB\x7f\n" + + "\x13ListSupportedModels\x12#.chat.v2.ListSupportedModelsRequest\x1a$.chat.v2.ListSupportedModelsResponse\" \x82\xd3\xe4\x93\x02\x1a\x12\x18/_pd/api/v2/chats/models\x12}\n" + + "\x0fGetCitationKeys\x12\x1f.chat.v2.GetCitationKeysRequest\x1a .chat.v2.GetCitationKeysResponse\"'\x82\xd3\xe4\x93\x02!\x12\x1f/_pd/api/v2/chats/citation-keysB\x7f\n" + "\vcom.chat.v2B\tChatProtoP\x01Z(paperdebugger/pkg/gen/api/chat/v2;chatv2\xa2\x02\x03CXX\xaa\x02\aChat.V2\xca\x02\aChat\\V2\xe2\x02\x13Chat\\V2\\GPBMetadata\xea\x02\bChat::V2b\x06proto3" var ( @@ -2056,7 +2162,7 @@ func file_chat_v2_chat_proto_rawDescGZIP() []byte { } var file_chat_v2_chat_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_chat_v2_chat_proto_msgTypes = make([]protoimpl.MessageInfo, 30) +var file_chat_v2_chat_proto_msgTypes = make([]protoimpl.MessageInfo, 32) var file_chat_v2_chat_proto_goTypes = []any{ (ConversationType)(0), // 0: chat.v2.ConversationType (*MessageTypeToolCall)(nil), // 1: chat.v2.MessageTypeToolCall @@ -2089,6 +2195,8 @@ var file_chat_v2_chat_proto_goTypes = []any{ (*StreamError)(nil), // 28: chat.v2.StreamError (*CreateConversationMessageStreamRequest)(nil), // 29: chat.v2.CreateConversationMessageStreamRequest (*CreateConversationMessageStreamResponse)(nil), // 30: chat.v2.CreateConversationMessageStreamResponse + (*GetCitationKeysRequest)(nil), // 31: chat.v2.GetCitationKeysRequest + (*GetCitationKeysResponse)(nil), // 32: chat.v2.GetCitationKeysResponse } var file_chat_v2_chat_proto_depIdxs = []int32{ 3, // 0: chat.v2.MessagePayload.system:type_name -> chat.v2.MessageTypeSystem @@ -2120,14 +2228,16 @@ var file_chat_v2_chat_proto_depIdxs = []int32{ 14, // 26: chat.v2.ChatService.UpdateConversation:input_type -> chat.v2.UpdateConversationRequest 16, // 27: chat.v2.ChatService.DeleteConversation:input_type -> chat.v2.DeleteConversationRequest 19, // 28: chat.v2.ChatService.ListSupportedModels:input_type -> chat.v2.ListSupportedModelsRequest - 11, // 29: chat.v2.ChatService.ListConversations:output_type -> chat.v2.ListConversationsResponse - 13, // 30: chat.v2.ChatService.GetConversation:output_type -> chat.v2.GetConversationResponse - 30, // 31: chat.v2.ChatService.CreateConversationMessageStream:output_type -> chat.v2.CreateConversationMessageStreamResponse - 15, // 32: chat.v2.ChatService.UpdateConversation:output_type -> chat.v2.UpdateConversationResponse - 17, // 33: chat.v2.ChatService.DeleteConversation:output_type -> chat.v2.DeleteConversationResponse - 20, // 34: chat.v2.ChatService.ListSupportedModels:output_type -> chat.v2.ListSupportedModelsResponse - 29, // [29:35] is the sub-list for method output_type - 23, // [23:29] is the sub-list for method input_type + 31, // 29: chat.v2.ChatService.GetCitationKeys:input_type -> chat.v2.GetCitationKeysRequest + 11, // 30: chat.v2.ChatService.ListConversations:output_type -> chat.v2.ListConversationsResponse + 13, // 31: chat.v2.ChatService.GetConversation:output_type -> chat.v2.GetConversationResponse + 30, // 32: chat.v2.ChatService.CreateConversationMessageStream:output_type -> chat.v2.CreateConversationMessageStreamResponse + 15, // 33: chat.v2.ChatService.UpdateConversation:output_type -> chat.v2.UpdateConversationResponse + 17, // 34: chat.v2.ChatService.DeleteConversation:output_type -> chat.v2.DeleteConversationResponse + 20, // 35: chat.v2.ChatService.ListSupportedModels:output_type -> chat.v2.ListSupportedModelsResponse + 32, // 36: chat.v2.ChatService.GetCitationKeys:output_type -> chat.v2.GetCitationKeysResponse + 30, // [30:37] is the sub-list for method output_type + 23, // [23:30] is the sub-list for method input_type 23, // [23:23] is the sub-list for extension type_name 23, // [23:23] is the sub-list for extension extendee 0, // [0:23] is the sub-list for field type_name @@ -2167,7 +2277,7 @@ func file_chat_v2_chat_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_chat_v2_chat_proto_rawDesc), len(file_chat_v2_chat_proto_rawDesc)), NumEnums: 1, - NumMessages: 30, + NumMessages: 32, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/gen/api/chat/v2/chat.pb.gw.go b/pkg/gen/api/chat/v2/chat.pb.gw.go index 81f7e4e6..3b8649de 100644 --- a/pkg/gen/api/chat/v2/chat.pb.gw.go +++ b/pkg/gen/api/chat/v2/chat.pb.gw.go @@ -237,6 +237,41 @@ func local_request_ChatService_ListSupportedModels_0(ctx context.Context, marsha return msg, metadata, err } +var filter_ChatService_GetCitationKeys_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} + +func request_ChatService_GetCitationKeys_0(ctx context.Context, marshaler runtime.Marshaler, client ChatServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetCitationKeysRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_ChatService_GetCitationKeys_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := client.GetCitationKeys(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_ChatService_GetCitationKeys_0(ctx context.Context, marshaler runtime.Marshaler, server ChatServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetCitationKeysRequest + metadata runtime.ServerMetadata + ) + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_ChatService_GetCitationKeys_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.GetCitationKeys(ctx, &protoReq) + return msg, metadata, err +} + // RegisterChatServiceHandlerServer registers the http handlers for service ChatService to "mux". // UnaryRPC :call ChatServiceServer directly. // StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. @@ -350,6 +385,26 @@ func RegisterChatServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux } forward_ChatService_ListSupportedModels_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodGet, pattern_ChatService_GetCitationKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/chat.v2.ChatService/GetCitationKeys", runtime.WithHTTPPathPattern("/_pd/api/v2/chats/citation-keys")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_ChatService_GetCitationKeys_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_ChatService_GetCitationKeys_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) return nil } @@ -492,6 +547,23 @@ func RegisterChatServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux } forward_ChatService_ListSupportedModels_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodGet, pattern_ChatService_GetCitationKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/chat.v2.ChatService/GetCitationKeys", runtime.WithHTTPPathPattern("/_pd/api/v2/chats/citation-keys")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_ChatService_GetCitationKeys_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_ChatService_GetCitationKeys_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) return nil } @@ -502,6 +574,7 @@ var ( pattern_ChatService_UpdateConversation_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"_pd", "api", "v2", "chats", "conversations", "conversation_id"}, "")) pattern_ChatService_DeleteConversation_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"_pd", "api", "v2", "chats", "conversations", "conversation_id"}, "")) pattern_ChatService_ListSupportedModels_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4}, []string{"_pd", "api", "v2", "chats", "models"}, "")) + pattern_ChatService_GetCitationKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4}, []string{"_pd", "api", "v2", "chats", "citation-keys"}, "")) ) var ( @@ -511,4 +584,5 @@ var ( forward_ChatService_UpdateConversation_0 = runtime.ForwardResponseMessage forward_ChatService_DeleteConversation_0 = runtime.ForwardResponseMessage forward_ChatService_ListSupportedModels_0 = runtime.ForwardResponseMessage + forward_ChatService_GetCitationKeys_0 = runtime.ForwardResponseMessage ) diff --git a/pkg/gen/api/chat/v2/chat_grpc.pb.go b/pkg/gen/api/chat/v2/chat_grpc.pb.go index 8303a8a8..bc0993b9 100644 --- a/pkg/gen/api/chat/v2/chat_grpc.pb.go +++ b/pkg/gen/api/chat/v2/chat_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: chat/v2/chat.proto @@ -25,6 +25,7 @@ const ( ChatService_UpdateConversation_FullMethodName = "/chat.v2.ChatService/UpdateConversation" ChatService_DeleteConversation_FullMethodName = "/chat.v2.ChatService/DeleteConversation" ChatService_ListSupportedModels_FullMethodName = "/chat.v2.ChatService/ListSupportedModels" + ChatService_GetCitationKeys_FullMethodName = "/chat.v2.ChatService/GetCitationKeys" ) // ChatServiceClient is the client API for ChatService service. @@ -37,6 +38,7 @@ type ChatServiceClient interface { UpdateConversation(ctx context.Context, in *UpdateConversationRequest, opts ...grpc.CallOption) (*UpdateConversationResponse, error) DeleteConversation(ctx context.Context, in *DeleteConversationRequest, opts ...grpc.CallOption) (*DeleteConversationResponse, error) ListSupportedModels(ctx context.Context, in *ListSupportedModelsRequest, opts ...grpc.CallOption) (*ListSupportedModelsResponse, error) + GetCitationKeys(ctx context.Context, in *GetCitationKeysRequest, opts ...grpc.CallOption) (*GetCitationKeysResponse, error) } type chatServiceClient struct { @@ -116,6 +118,16 @@ func (c *chatServiceClient) ListSupportedModels(ctx context.Context, in *ListSup return out, nil } +func (c *chatServiceClient) GetCitationKeys(ctx context.Context, in *GetCitationKeysRequest, opts ...grpc.CallOption) (*GetCitationKeysResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetCitationKeysResponse) + err := c.cc.Invoke(ctx, ChatService_GetCitationKeys_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // ChatServiceServer is the server API for ChatService service. // All implementations must embed UnimplementedChatServiceServer // for forward compatibility. @@ -126,6 +138,7 @@ type ChatServiceServer interface { UpdateConversation(context.Context, *UpdateConversationRequest) (*UpdateConversationResponse, error) DeleteConversation(context.Context, *DeleteConversationRequest) (*DeleteConversationResponse, error) ListSupportedModels(context.Context, *ListSupportedModelsRequest) (*ListSupportedModelsResponse, error) + GetCitationKeys(context.Context, *GetCitationKeysRequest) (*GetCitationKeysResponse, error) mustEmbedUnimplementedChatServiceServer() } @@ -154,6 +167,9 @@ func (UnimplementedChatServiceServer) DeleteConversation(context.Context, *Delet func (UnimplementedChatServiceServer) ListSupportedModels(context.Context, *ListSupportedModelsRequest) (*ListSupportedModelsResponse, error) { return nil, status.Error(codes.Unimplemented, "method ListSupportedModels not implemented") } +func (UnimplementedChatServiceServer) GetCitationKeys(context.Context, *GetCitationKeysRequest) (*GetCitationKeysResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetCitationKeys not implemented") +} func (UnimplementedChatServiceServer) mustEmbedUnimplementedChatServiceServer() {} func (UnimplementedChatServiceServer) testEmbeddedByValue() {} @@ -276,6 +292,24 @@ func _ChatService_ListSupportedModels_Handler(srv interface{}, ctx context.Conte return interceptor(ctx, in, info, handler) } +func _ChatService_GetCitationKeys_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetCitationKeysRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ChatServiceServer).GetCitationKeys(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ChatService_GetCitationKeys_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ChatServiceServer).GetCitationKeys(ctx, req.(*GetCitationKeysRequest)) + } + return interceptor(ctx, in, info, handler) +} + // ChatService_ServiceDesc is the grpc.ServiceDesc for ChatService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -303,6 +337,10 @@ var ChatService_ServiceDesc = grpc.ServiceDesc{ MethodName: "ListSupportedModels", Handler: _ChatService_ListSupportedModels_Handler, }, + { + MethodName: "GetCitationKeys", + Handler: _ChatService_GetCitationKeys_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/pkg/gen/api/comment/v1/comment_grpc.pb.go b/pkg/gen/api/comment/v1/comment_grpc.pb.go index b077d68b..a4217a6a 100644 --- a/pkg/gen/api/comment/v1/comment_grpc.pb.go +++ b/pkg/gen/api/comment/v1/comment_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: comment/v1/comment.proto diff --git a/pkg/gen/api/project/v1/project_grpc.pb.go b/pkg/gen/api/project/v1/project_grpc.pb.go index c50d3475..dd49e74f 100644 --- a/pkg/gen/api/project/v1/project_grpc.pb.go +++ b/pkg/gen/api/project/v1/project_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: project/v1/project.proto diff --git a/pkg/gen/api/user/v1/user.pb.go b/pkg/gen/api/user/v1/user.pb.go index c54615c4..41752992 100644 --- a/pkg/gen/api/user/v1/user.pb.go +++ b/pkg/gen/api/user/v1/user.pb.go @@ -619,7 +619,7 @@ type Settings struct { state protoimpl.MessageState `protogen:"open.v1"` ShowShortcutsAfterSelection bool `protobuf:"varint,1,opt,name=show_shortcuts_after_selection,json=showShortcutsAfterSelection,proto3" json:"show_shortcuts_after_selection,omitempty"` FullWidthPaperDebuggerButton bool `protobuf:"varint,2,opt,name=full_width_paper_debugger_button,json=fullWidthPaperDebuggerButton,proto3" json:"full_width_paper_debugger_button,omitempty"` - EnableCompletion bool `protobuf:"varint,3,opt,name=enable_completion,json=enableCompletion,proto3" json:"enable_completion,omitempty"` + EnableCitationSuggestion bool `protobuf:"varint,3,opt,name=enable_citation_suggestion,json=enableCitationSuggestion,proto3" json:"enable_citation_suggestion,omitempty"` FullDocumentRag bool `protobuf:"varint,4,opt,name=full_document_rag,json=fullDocumentRag,proto3" json:"full_document_rag,omitempty"` ShowedOnboarding bool `protobuf:"varint,5,opt,name=showed_onboarding,json=showedOnboarding,proto3" json:"showed_onboarding,omitempty"` OpenaiApiKey string `protobuf:"bytes,6,opt,name=openai_api_key,json=openaiApiKey,proto3" json:"openai_api_key,omitempty"` @@ -671,9 +671,9 @@ func (x *Settings) GetFullWidthPaperDebuggerButton() bool { return false } -func (x *Settings) GetEnableCompletion() bool { +func (x *Settings) GetEnableCitationSuggestion() bool { if x != nil { - return x.EnableCompletion + return x.EnableCitationSuggestion } return false } @@ -1153,11 +1153,11 @@ const file_user_v1_user_proto_rawDesc = "" + "\x06prompt\x18\x01 \x01(\v2\x0f.user.v1.PromptR\x06prompt\"2\n" + "\x13DeletePromptRequest\x12\x1b\n" + "\tprompt_id\x18\x01 \x01(\tR\bpromptId\"\x16\n" + - "\x14DeletePromptResponse\"\xc3\x02\n" + + "\x14DeletePromptResponse\"\xd4\x02\n" + "\bSettings\x12C\n" + "\x1eshow_shortcuts_after_selection\x18\x01 \x01(\bR\x1bshowShortcutsAfterSelection\x12F\n" + - " full_width_paper_debugger_button\x18\x02 \x01(\bR\x1cfullWidthPaperDebuggerButton\x12+\n" + - "\x11enable_completion\x18\x03 \x01(\bR\x10enableCompletion\x12*\n" + + " full_width_paper_debugger_button\x18\x02 \x01(\bR\x1cfullWidthPaperDebuggerButton\x12<\n" + + "\x1aenable_citation_suggestion\x18\x03 \x01(\bR\x18enableCitationSuggestion\x12*\n" + "\x11full_document_rag\x18\x04 \x01(\bR\x0ffullDocumentRag\x12+\n" + "\x11showed_onboarding\x18\x05 \x01(\bR\x10showedOnboarding\x12$\n" + "\x0eopenai_api_key\x18\x06 \x01(\tR\fopenaiApiKey\"\x14\n" + diff --git a/pkg/gen/api/user/v1/user_grpc.pb.go b/pkg/gen/api/user/v1/user_grpc.pb.go index 898ff765..1f96307b 100644 --- a/pkg/gen/api/user/v1/user_grpc.pb.go +++ b/pkg/gen/api/user/v1/user_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: user/v1/user.proto diff --git a/proto/chat/v2/chat.proto b/proto/chat/v2/chat.proto index 8dd650a3..779fe913 100644 --- a/proto/chat/v2/chat.proto +++ b/proto/chat/v2/chat.proto @@ -31,6 +31,9 @@ service ChatService { rpc ListSupportedModels(ListSupportedModelsRequest) returns (ListSupportedModelsResponse) { option (google.api.http) = {get: "/_pd/api/v2/chats/models"}; } + rpc GetCitationKeys(GetCitationKeysRequest) returns (GetCitationKeysResponse) { + option (google.api.http) = {get: "/_pd/api/v2/chats/citation-keys"}; + } } message MessageTypeToolCall { @@ -234,3 +237,14 @@ message CreateConversationMessageStreamResponse { ReasoningChunk reasoning_chunk = 8; } } + +// Request to get citation keys suggestion based on project bibliography +message GetCitationKeysRequest { + string sentence = 1; + string project_id = 2; +} + +// Response containing the suggested citation keys +message GetCitationKeysResponse { + repeated string citation_keys = 1; +} \ No newline at end of file diff --git a/proto/user/v1/user.proto b/proto/user/v1/user.proto index 08056803..fc7f02b5 100644 --- a/proto/user/v1/user.proto +++ b/proto/user/v1/user.proto @@ -117,7 +117,7 @@ message DeletePromptResponse {} message Settings { bool show_shortcuts_after_selection = 1; bool full_width_paper_debugger_button = 2; - bool enable_completion = 3; + bool enable_citation_suggestion = 3; bool full_document_rag = 4; bool showed_onboarding = 5; string openai_api_key = 6; diff --git a/webapp/_webapp/src/components/code-block.tsx b/webapp/_webapp/src/components/code-block.tsx index 16c71ec5..f237456d 100644 --- a/webapp/_webapp/src/components/code-block.tsx +++ b/webapp/_webapp/src/components/code-block.tsx @@ -3,7 +3,7 @@ import "highlight.js/styles/default.min.css"; import latex from "highlight.js/lib/languages/latex"; hljs.registerLanguage("latex", latex); -import { useState, useEffect } from "react"; +import { useMemo } from "react"; type CodeBlockProps = { code: string; @@ -11,11 +11,7 @@ type CodeBlockProps = { }; export const CodeBlock = ({ code, className }: CodeBlockProps) => { - const [highlightedCode, setHighlightedCode] = useState(code); - - useEffect(() => { - setHighlightedCode(hljs.highlight(code, { language: "latex" }).value); - }, [code]); + const highlightedCode = useMemo(() => hljs.highlight(code, { language: "latex" }).value, [code]); return (
 {
   // State
-  const [progress, setProgress] = useState(0);
-  const [phase, setPhase] = useState("green");
-  const [isTimeout, setIsTimeout] = useState(false);
+  const [{ progress, phase, isTimeout }, dispatch] = useReducer(indicatorReducer, {
+    progress: 0,
+    phase: "green",
+    isTimeout: false,
+  });
 
   // Handle progress animation
   useEffect(() => {
@@ -103,18 +129,18 @@ export const LoadingIndicator = ({ text = "Thinking", estimatedSeconds = 0, erro
         // we spend 100% of estimatedDuration in green,
         // 50% in orange, and 50% in red before warning.
         if (phase === "green") {
-          setPhase("orange");
+          dispatch({ type: "ADVANCE_PHASE", nextPhase: "orange" });
           currentProgress = 0;
         } else if (phase === "orange") {
-          setPhase("red");
+          dispatch({ type: "ADVANCE_PHASE", nextPhase: "red" });
           currentProgress = 0;
         } else if (phase === "red") {
-          setIsTimeout(true);
+          dispatch({ type: "SET_TIMEOUT" });
           return;
         }
       }
 
-      setProgress(currentProgress);
+      dispatch({ type: "SET_PROGRESS", progress: currentProgress });
 
       if (!isTimeout) {
         animationFrameId = requestAnimationFrame(updateProgress);
diff --git a/webapp/_webapp/src/components/message-entry-container/assistant.tsx b/webapp/_webapp/src/components/message-entry-container/assistant.tsx
index 47743da3..92f30492 100644
--- a/webapp/_webapp/src/components/message-entry-container/assistant.tsx
+++ b/webapp/_webapp/src/components/message-entry-container/assistant.tsx
@@ -132,8 +132,8 @@ export const AssistantMessageContainer = ({
             )}
 
             {/* PaperDebugger blocks */}
-            {parsedMessage.paperDebuggerContent.map((content, index) => (
-              
+            {parsedMessage.paperDebuggerContent.map((content) => (
+              
                 {content}
               
             ))}
@@ -147,7 +147,17 @@ export const AssistantMessageContainer = ({
           {((parsedMessage.regularContent?.length || 0) > 0 || parsedMessage.paperDebuggerContent.length > 0) && (
             
- + { + if (e.key === "Enter" || e.key === " ") { + handleCopy(); + } + }} + tabIndex={0} + role="button" + aria-label="Copy message" + > diff --git a/webapp/_webapp/src/components/message-entry-container/tools/general.tsx b/webapp/_webapp/src/components/message-entry-container/tools/general.tsx index f07a1c3a..db3aa39b 100644 --- a/webapp/_webapp/src/components/message-entry-container/tools/general.tsx +++ b/webapp/_webapp/src/components/message-entry-container/tools/general.tsx @@ -95,7 +95,17 @@ export const GeneralToolCard = ({ // When there is a message, show the compact card with collapsible content return (
-
+
{ + if (e.key === "Enter" || e.key === " ") { + toggleCollapse(); + } + }} + >