|
| 1 | +package client |
| 2 | + |
| 3 | +// TODO: This file should not place in the client package. |
| 4 | +import ( |
| 5 | + "context" |
| 6 | + "fmt" |
| 7 | + "paperdebugger/internal/models" |
| 8 | + "paperdebugger/internal/services/toolkit/tools/xtramcp" |
| 9 | + "regexp" |
| 10 | + "strings" |
| 11 | + |
| 12 | + "github.com/openai/openai-go/v3" |
| 13 | + "go.mongodb.org/mongo-driver/v2/bson" |
| 14 | +) |
| 15 | + |
| 16 | +var ( |
| 17 | + // Regex patterns compiled once |
| 18 | + titleFieldRe = regexp.MustCompile(`(?i)title\s*=\s*`) // matches "title = " prefix |
| 19 | + entryStartRe = regexp.MustCompile(`(?i)^\s*@(\w+)\s*\{`) // eg. @article{ |
| 20 | + stringEntryRe = regexp.MustCompile(`(?i)^\s*@String\s*\{`) // eg. @String{ |
| 21 | + multiSpaceRe = regexp.MustCompile(` {2,}`) |
| 22 | + |
| 23 | + // Fields to exclude from bibliography (not useful for citation matching) |
| 24 | + excludedFields = []string{ |
| 25 | + "address", "institution", "pages", "eprint", "primaryclass", "volume", "number", |
| 26 | + "edition", "numpages", "articleno", "publisher", "editor", "doi", "url", "acmid", |
| 27 | + "issn", "archivePrefix", "year", "month", "day", "eid", "lastaccessed", "organization", |
| 28 | + "school", "isbn", "mrclass", "mrnumber", "mrreviewer", "type", "order_no", "location", |
| 29 | + "howpublished", "distincturl", "issue_date", "archived", "series", "source", |
| 30 | + } |
| 31 | + excludeFieldRe = regexp.MustCompile(`(?i)^\s*(` + strings.Join(excludedFields, "|") + `)\s*=`) |
| 32 | +) |
| 33 | + |
| 34 | +// braceBalance returns the net brace count (opens - closes) in a string. |
| 35 | +func braceBalance(s string) int { |
| 36 | + return strings.Count(s, "{") - strings.Count(s, "}") |
| 37 | +} |
| 38 | + |
| 39 | +// isQuoteUnclosed returns true if the string has an odd number of double quotes. |
| 40 | +func isQuoteUnclosed(s string) bool { |
| 41 | + return strings.Count(s, `"`)%2 == 1 |
| 42 | +} |
| 43 | + |
| 44 | +// extractBalancedValue extracts a BibTeX field value (braced or quoted) starting at pos. |
| 45 | +// It is needed for (1) getting full title (for abstract lookup) and (2) skipping excluded |
| 46 | +// fields that may span multiple lines. |
| 47 | +// Returns the extracted content and end position, or empty string and -1 if not found. |
| 48 | +func extractBalancedValue(s string, pos int) (string, int) { |
| 49 | + // Skip whitespace |
| 50 | + for pos < len(s) && (s[pos] == ' ' || s[pos] == '\t' || s[pos] == '\n' || s[pos] == '\r') { |
| 51 | + pos++ |
| 52 | + } |
| 53 | + if pos >= len(s) { |
| 54 | + return "", -1 |
| 55 | + } |
| 56 | + |
| 57 | + switch s[pos] { |
| 58 | + case '{': |
| 59 | + depth := 0 |
| 60 | + start := pos + 1 |
| 61 | + for i := pos; i < len(s); i++ { |
| 62 | + switch s[i] { |
| 63 | + case '{': |
| 64 | + depth++ |
| 65 | + case '}': |
| 66 | + depth-- |
| 67 | + if depth == 0 { |
| 68 | + return s[start:i], i + 1 |
| 69 | + } |
| 70 | + } |
| 71 | + } |
| 72 | + case '"': |
| 73 | + start := pos + 1 |
| 74 | + for i := start; i < len(s); i++ { |
| 75 | + if s[i] == '"' { |
| 76 | + return s[start:i], i + 1 |
| 77 | + } |
| 78 | + } |
| 79 | + } |
| 80 | + return "", -1 |
| 81 | +} |
| 82 | + |
| 83 | +// extractTitle extracts the title from a BibTeX entry string. |
| 84 | +// It handles nested braces like title = {A Study of {COVID-19}}. |
| 85 | +func extractTitle(entry string) string { |
| 86 | + loc := titleFieldRe.FindStringIndex(entry) |
| 87 | + if loc == nil { |
| 88 | + return "" |
| 89 | + } |
| 90 | + content, _ := extractBalancedValue(entry, loc[1]) |
| 91 | + return strings.TrimSpace(content) |
| 92 | +} |
| 93 | + |
| 94 | +// parseBibFile extracts bibliography entries from a .bib file's lines, |
| 95 | +// filtering out @String macros, comments, and excluded fields (url, doi, etc.). |
| 96 | +func parseBibFile(lines []string) []string { |
| 97 | + var entries []string |
| 98 | + var currentEntry []string |
| 99 | + |
| 100 | + // It handles multi-line field values by tracking brace/quote balance: |
| 101 | + // - skipBraces > 0: currently skipping a {bracketed} value, wait until balanced |
| 102 | + // - skipQuotes = true: currently skipping a "quoted" value, wait for closing quote |
| 103 | + |
| 104 | + var entryDepth int // brace depth for current entry (0 = entry complete) |
| 105 | + var skipBraces int // > 0 means we're skipping lines until braces balance |
| 106 | + var skipQuotes bool // true means we're skipping lines until closing quote |
| 107 | + |
| 108 | + for _, line := range lines { |
| 109 | + // Skip empty lines and comments |
| 110 | + if trimmed := strings.TrimSpace(line); trimmed == "" || strings.HasPrefix(trimmed, "%") { |
| 111 | + continue |
| 112 | + } |
| 113 | + |
| 114 | + // If skipping a multi-line {bracketed} field value, keep skipping until balanced |
| 115 | + if skipBraces > 0 { |
| 116 | + skipBraces += braceBalance(line) |
| 117 | + continue |
| 118 | + } |
| 119 | + |
| 120 | + // If skipping a multi-line "quoted" field value, keep skipping until closing quote |
| 121 | + if skipQuotes { |
| 122 | + if isQuoteUnclosed(line) { // odd quote count = found closing quote |
| 123 | + skipQuotes = false |
| 124 | + } |
| 125 | + continue |
| 126 | + } |
| 127 | + |
| 128 | + // Skip @String{...} macro definitions |
| 129 | + if stringEntryRe.MatchString(line) { |
| 130 | + skipBraces = braceBalance(line) |
| 131 | + continue |
| 132 | + } |
| 133 | + |
| 134 | + // Skip excluded fields (url, doi, pages, etc.) - may span multiple lines |
| 135 | + if excludeFieldRe.MatchString(line) { |
| 136 | + if strings.Contains(line, "={") || strings.Contains(line, "= {") { |
| 137 | + skipBraces = braceBalance(line) |
| 138 | + } else if strings.Contains(line, `="`) || strings.Contains(line, `= "`) { |
| 139 | + skipQuotes = isQuoteUnclosed(line) |
| 140 | + } |
| 141 | + continue |
| 142 | + } |
| 143 | + |
| 144 | + // Start of new entry: @article{key, or @book{key, etc. |
| 145 | + if entryStartRe.MatchString(line) { |
| 146 | + if len(currentEntry) > 0 { |
| 147 | + entries = append(entries, strings.Join(currentEntry, "\n")) |
| 148 | + } |
| 149 | + currentEntry = []string{line} |
| 150 | + entryDepth = braceBalance(line) |
| 151 | + continue |
| 152 | + } |
| 153 | + |
| 154 | + // Continue building current entry |
| 155 | + if len(currentEntry) > 0 { |
| 156 | + currentEntry = append(currentEntry, line) |
| 157 | + entryDepth += braceBalance(line) |
| 158 | + if entryDepth <= 0 { // entry complete when braces balance |
| 159 | + entries = append(entries, strings.Join(currentEntry, "\n")) |
| 160 | + currentEntry = nil |
| 161 | + } |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + // Last entry if file doesn't end with balanced braces |
| 166 | + if len(currentEntry) > 0 { |
| 167 | + entries = append(entries, strings.Join(currentEntry, "\n")) |
| 168 | + } |
| 169 | + return entries |
| 170 | +} |
| 171 | + |
| 172 | +// fetchAbstracts enriches entries with abstracts from XtraMCP using batch API. |
| 173 | +func (a *AIClientV2) fetchAbstracts(ctx context.Context, entries []string) []string { |
| 174 | + // Extract titles |
| 175 | + var titles []string |
| 176 | + for _, entry := range entries { |
| 177 | + if title := extractTitle(entry); title != "" { |
| 178 | + titles = append(titles, title) |
| 179 | + } |
| 180 | + } |
| 181 | + |
| 182 | + // Fetch abstracts and build lookup map |
| 183 | + abstracts := make(map[string]string) |
| 184 | + svc := xtramcp.NewXtraMCPServices(a.cfg.XtraMCPURI) |
| 185 | + resp, err := svc.GetPaperAbstracts(ctx, titles) |
| 186 | + if err == nil && resp.Success { |
| 187 | + for _, r := range resp.Results { |
| 188 | + if r.Found { |
| 189 | + abstracts[r.Title] = r.Abstract |
| 190 | + } |
| 191 | + } |
| 192 | + } |
| 193 | + |
| 194 | + // Enrich entries |
| 195 | + result := make([]string, len(entries)) |
| 196 | + for i, entry := range entries { |
| 197 | + if abstract, ok := abstracts[extractTitle(entry)]; ok && abstract != "" { |
| 198 | + if pos := strings.LastIndex(entry, "}"); pos > 0 { |
| 199 | + result[i] = entry[:pos] + fmt.Sprintf(",\n abstract = {%s}\n}", abstract) |
| 200 | + continue |
| 201 | + } |
| 202 | + } |
| 203 | + result[i] = entry |
| 204 | + } |
| 205 | + return result |
| 206 | +} |
| 207 | + |
| 208 | +// GetBibliographyForCitation extracts bibliography content from a project's .bib files. |
| 209 | +// It excludes non-essential fields to save tokens and fetches abstracts from XtraMCP. |
| 210 | +func (a *AIClientV2) GetBibliographyForCitation(ctx context.Context, userId bson.ObjectID, projectId string) (string, error) { |
| 211 | + project, err := a.projectService.GetProject(ctx, userId, projectId) |
| 212 | + if err != nil { |
| 213 | + return "", err |
| 214 | + } |
| 215 | + |
| 216 | + // Parse all .bib files |
| 217 | + var entries []string |
| 218 | + for _, doc := range project.Docs { |
| 219 | + if strings.HasSuffix(doc.Filepath, ".bib") { |
| 220 | + entries = append(entries, parseBibFile(doc.Lines)...) |
| 221 | + } |
| 222 | + } |
| 223 | + |
| 224 | + // Enrich with abstracts |
| 225 | + entries = a.fetchAbstracts(ctx, entries) |
| 226 | + |
| 227 | + // Join and normalize |
| 228 | + bibliography := strings.Join(entries, "\n") |
| 229 | + return multiSpaceRe.ReplaceAllString(bibliography, " "), nil |
| 230 | +} |
| 231 | + |
| 232 | +func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userId bson.ObjectID, projectId string, llmProvider *models.LLMProviderConfig) ([]string, error) { |
| 233 | + bibliography, err := a.GetBibliographyForCitation(ctx, userId, projectId) |
| 234 | + |
| 235 | + if err != nil { |
| 236 | + return nil, err |
| 237 | + } |
| 238 | + |
| 239 | + emptyCitation := "none" |
| 240 | + |
| 241 | + // Bibliography is placed at the start of the prompt to leverage prompt caching |
| 242 | + message := fmt.Sprintf("Bibliography: %s\nSentence: %s\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '%s'.", bibliography, sentence, emptyCitation) |
| 243 | + |
| 244 | + _, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{ |
| 245 | + openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."), |
| 246 | + openai.UserMessage(message), |
| 247 | + }, llmProvider) |
| 248 | + |
| 249 | + if err != nil { |
| 250 | + return nil, err |
| 251 | + } |
| 252 | + |
| 253 | + if len(resp) == 0 { |
| 254 | + return []string{}, nil |
| 255 | + } |
| 256 | + |
| 257 | + citationKeysStr := strings.TrimSpace(resp[0].Payload.GetAssistant().GetContent()) |
| 258 | + |
| 259 | + if citationKeysStr == "" || citationKeysStr == emptyCitation { |
| 260 | + return []string{}, nil |
| 261 | + } |
| 262 | + |
| 263 | + // Parse comma-separated keys |
| 264 | + var result []string |
| 265 | + for _, key := range strings.Split(citationKeysStr, ",") { |
| 266 | + if trimmed := strings.TrimSpace(key); trimmed != "" { |
| 267 | + result = append(result, trimmed) |
| 268 | + } |
| 269 | + } |
| 270 | + |
| 271 | + return result, nil |
| 272 | +} |
0 commit comments