Skip to content

Commit 54514a8

Browse files
authored
Merge pull request #129 from PaperDebugger/feat/byok
feat: byok
2 parents adb9032 + 0022664 commit 54514a8

File tree

18 files changed

+799
-156
lines changed

18 files changed

+799
-156
lines changed

internal/api/chat/create_conversation_message_stream_v2.go

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"paperdebugger/internal/models"
99
"paperdebugger/internal/services"
1010
chatv2 "paperdebugger/pkg/gen/api/chat/v2"
11+
"strings"
1112

1213
"github.com/google/uuid"
1314
"github.com/openai/openai-go/v3"
@@ -276,9 +277,40 @@ func (s *ChatServerV2) CreateConversationMessageStream(
276277
return s.sendStreamError(stream, err)
277278
}
278279

280+
// Check if user has an API key for requested model
281+
var llmProvider *models.LLMProviderConfig
282+
var customModel *models.CustomModel
283+
customModel = nil
284+
for i := range settings.CustomModels {
285+
if settings.CustomModels[i].Slug == modelSlug {
286+
customModel = &settings.CustomModels[i]
287+
}
288+
}
289+
279290
// Usage is the same as ChatCompletion, just passing the stream parameter
280-
llmProvider := &models.LLMProviderConfig{
281-
APIKey: settings.OpenAIAPIKey,
291+
292+
if customModel == nil {
293+
// User did not specify API key for this model
294+
llmProvider = &models.LLMProviderConfig{
295+
APIKey: "",
296+
IsCustomModel: false,
297+
}
298+
} else {
299+
customModel.BaseUrl = strings.ToLower(customModel.BaseUrl)
300+
301+
if strings.Contains(customModel.BaseUrl, "paperdebugger.com") {
302+
customModel.BaseUrl = ""
303+
}
304+
if !strings.HasPrefix(customModel.BaseUrl, "https://") {
305+
customModel.BaseUrl = strings.Replace(customModel.BaseUrl, "http://", "", 1)
306+
customModel.BaseUrl = "https://" + customModel.BaseUrl
307+
}
308+
309+
llmProvider = &models.LLMProviderConfig{
310+
APIKey: customModel.APIKey,
311+
Endpoint: customModel.BaseUrl,
312+
IsCustomModel: true,
313+
}
282314
}
283315

284316
openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
@@ -307,7 +339,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
307339
for i, bsonMsg := range conversation.InappChatHistory {
308340
protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg)
309341
}
310-
title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider)
342+
title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider, modelSlug)
311343
if err != nil {
312344
s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
313345
return

internal/api/chat/list_supported_models_v2.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package chat
22

33
import (
44
"context"
5-
"strings"
65

76
"paperdebugger/internal/libs/contextutil"
87
chatv2 "paperdebugger/pkg/gen/api/chat/v2"
@@ -220,32 +219,33 @@ func (s *ChatServerV2) ListSupportedModels(
220219
return nil, err
221220
}
222221

223-
hasOwnAPIKey := strings.TrimSpace(settings.OpenAIAPIKey) != ""
224-
225222
var models []*chatv2.SupportedModel
226-
for _, config := range allModels {
227-
// Choose the appropriate slug based on whether user has their own API key.
228-
//
229-
// Some models are only available via OpenRouter; for those, slugOpenAI may be empty.
230-
// In that case, keep using the OpenRouter slug to avoid returning an empty model slug.
231-
slug := config.slugOpenRouter
232-
if hasOwnAPIKey && strings.TrimSpace(config.slugOpenAI) != "" {
233-
slug = config.slugOpenAI
234-
}
235223

224+
for _, model := range settings.CustomModels {
225+
models = append(models, &chatv2.SupportedModel{
226+
Name: model.Name,
227+
Slug: model.Slug,
228+
TotalContext: int64(model.ContextWindow),
229+
MaxOutput: int64(model.MaxOutput),
230+
InputPrice: int64(model.InputPrice),
231+
OutputPrice: int64(model.OutputPrice),
232+
IsCustom: true,
233+
})
234+
}
235+
236+
for _, config := range allModels {
236237
model := &chatv2.SupportedModel{
237238
Name: config.name,
238-
Slug: slug,
239+
Slug: config.slugOpenRouter,
239240
TotalContext: config.totalContext,
240241
MaxOutput: config.maxOutput,
241242
InputPrice: config.inputPrice,
242243
OutputPrice: config.outputPrice,
243244
}
244245

245246
// If model requires own key but user hasn't provided one, mark as disabled
246-
if config.requireOwnKey && !hasOwnAPIKey {
247-
model.Disabled = true
248-
model.DisabledReason = stringPtr("Requires your own OpenAI API key. Configure it in Settings.")
247+
if config.requireOwnKey {
248+
continue
249249
}
250250

251251
models = append(models, model)

internal/api/mapper/user.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,69 @@ package mapper
33
import (
44
"paperdebugger/internal/models"
55
userv1 "paperdebugger/pkg/gen/api/user/v1"
6+
7+
"go.mongodb.org/mongo-driver/v2/bson"
68
)
79

810
func MapProtoSettingsToModel(settings *userv1.Settings) *models.Settings {
11+
// Map the slice of custom models
12+
customModels := make([]models.CustomModel, len(settings.CustomModels))
13+
for i, m := range settings.CustomModels {
14+
var id bson.ObjectID
15+
16+
id, err := bson.ObjectIDFromHex(m.Id)
17+
if err != nil {
18+
id = bson.NewObjectID()
19+
}
20+
21+
customModels[i] = models.CustomModel{
22+
Id: id,
23+
Slug: m.Slug,
24+
Name: m.Name,
25+
BaseUrl: m.BaseUrl,
26+
APIKey: m.ApiKey,
27+
ContextWindow: m.ContextWindow,
28+
MaxOutput: m.MaxOutput,
29+
InputPrice: m.InputPrice,
30+
OutputPrice: m.OutputPrice,
31+
}
32+
}
33+
934
return &models.Settings{
1035
ShowShortcutsAfterSelection: settings.ShowShortcutsAfterSelection,
1136
FullWidthPaperDebuggerButton: settings.FullWidthPaperDebuggerButton,
12-
EnableCitationSuggestion: settings.EnableCitationSuggestion,
37+
EnableCitationSuggestion: settings.EnableCitationSuggestion,
1338
FullDocumentRag: settings.FullDocumentRag,
1439
ShowedOnboarding: settings.ShowedOnboarding,
1540
OpenAIAPIKey: settings.OpenaiApiKey,
41+
CustomModels: customModels,
1642
}
1743
}
1844

1945
func MapModelSettingsToProto(settings *models.Settings) *userv1.Settings {
46+
// Map the slice back to Proto
47+
customModels := make([]*userv1.CustomModel, len(settings.CustomModels))
48+
for i, m := range settings.CustomModels {
49+
customModels[i] = &userv1.CustomModel{
50+
Id: m.Id.Hex(),
51+
Slug: m.Slug,
52+
Name: m.Name,
53+
BaseUrl: m.BaseUrl,
54+
ApiKey: m.APIKey,
55+
ContextWindow: m.ContextWindow,
56+
MaxOutput: m.MaxOutput,
57+
InputPrice: m.InputPrice,
58+
OutputPrice: m.OutputPrice,
59+
}
60+
}
61+
2062
return &userv1.Settings{
2163
ShowShortcutsAfterSelection: settings.ShowShortcutsAfterSelection,
2264
FullWidthPaperDebuggerButton: settings.FullWidthPaperDebuggerButton,
23-
EnableCitationSuggestion: settings.EnableCitationSuggestion,
65+
EnableCitationSuggestion: settings.EnableCitationSuggestion,
2466
FullDocumentRag: settings.FullDocumentRag,
2567
ShowedOnboarding: settings.ShowedOnboarding,
2668
OpenaiApiKey: settings.OpenAIAPIKey,
69+
CustomModels: customModels,
2770
}
2871
}

internal/models/llm_provider.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package models
22

33
// LLMProviderConfig holds the configuration for LLM API calls.
// If both Endpoint and APIKey are empty, the system default will be used.
// If IsCustomModel is true, the user-requested slug with corresponding
// API keys and endpoint should be used.
type LLMProviderConfig struct {
	Endpoint      string // provider base URL; empty means "use system default"
	APIKey        string // provider API key; empty means "use system default"
	ModelName     string // model identifier override; usage not shown in this diff — TODO confirm
	IsCustomModel bool   // true when the config comes from a user-defined custom model entry
}
1013

1114
// IsCustom returns true if the user has configured custom LLM provider settings.

internal/models/user.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,26 @@ package models
22

33
import "go.mongodb.org/mongo-driver/v2/bson"
44

5+
// CustomModel is a user-configured ("bring your own key") model entry
// persisted in the user's settings document.
type CustomModel struct {
	Id            bson.ObjectID `bson:"_id"`            // unique identifier for this entry
	Slug          string        `bson:"slug"`           // slug clients use to request this model
	Name          string        `bson:"name"`           // human-readable display name
	BaseUrl       string        `bson:"base_url"`       // provider endpoint base URL
	APIKey        string        `bson:"api_key"`        // user's own API key for this provider
	ContextWindow int32         `bson:"context_window"` // total context size; presumably tokens — TODO confirm units
	MaxOutput     int32         `bson:"max_output"`     // maximum output size; presumably tokens — TODO confirm units
	InputPrice    int32         `bson:"input_price"`    // input pricing; units/currency not shown here
	OutputPrice   int32         `bson:"output_price"`   // output pricing; units/currency not shown here
}
16+
517
// Settings holds per-user preferences plus the user's LLM credentials,
// including any custom (BYOK) model configurations.
type Settings struct {
	ShowShortcutsAfterSelection  bool          `bson:"show_shortcuts_after_selection"`
	FullWidthPaperDebuggerButton bool          `bson:"full_width_paper_debugger_button"`
	EnableCitationSuggestion     bool          `bson:"enable_citation_suggestion"`
	FullDocumentRag              bool          `bson:"full_document_rag"`
	ShowedOnboarding             bool          `bson:"showed_onboarding"`
	OpenAIAPIKey                 string        `bson:"openai_api_key"` // user-supplied OpenAI key; empty when not configured
	CustomModels                 []CustomModel `bson:"custom_models"`  // user-defined BYOK model entries
}
1326

1427
type User struct {

internal/services/toolkit/client/client_v2.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,20 @@ func (a *AIClientV2) GetOpenAIClient(llmConfig *models.LLMProviderConfig) *opena
3232
var Endpoint string = llmConfig.Endpoint
3333
var APIKey string = llmConfig.APIKey
3434

35-
if Endpoint == "" {
36-
if APIKey != "" {
37-
// User provided their own API key, use the OpenAI-compatible endpoint
38-
Endpoint = a.cfg.OpenAIBaseURL // standard openai base url
39-
} else {
40-
// suffix needed for cloudflare gateway
41-
Endpoint = a.cfg.InferenceBaseURL + "/openrouter"
35+
if !llmConfig.IsCustomModel {
36+
if Endpoint == "" {
37+
if APIKey != "" {
38+
// User provided their own API key, use the OpenAI-compatible endpoint
39+
Endpoint = a.cfg.OpenAIBaseURL // standard openai base url
40+
} else {
41+
// suffix needed for cloudflare gateway
42+
Endpoint = a.cfg.InferenceBaseURL + "/openrouter"
43+
}
4244
}
43-
}
4445

45-
if APIKey == "" {
46-
APIKey = a.cfg.InferenceAPIKey
46+
if APIKey == "" {
47+
APIKey = a.cfg.InferenceAPIKey
48+
}
4749
}
4850

4951
opts := []option.RequestOption{

internal/services/toolkit/client/completion_v2.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
6666
}()
6767

6868
oaiClient := a.GetOpenAIClient(llmProvider)
69-
params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry)
69+
params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry, llmProvider.IsCustomModel)
7070

7171
for {
7272
params.Messages = openaiChatHistory

internal/services/toolkit/client/get_conversation_title_v2.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"github.com/samber/lo"
1414
)
1515

16-
func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig) (string, error) {
16+
func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig, modelSlug string) (string, error) {
1717
messages := lo.Map(inappChatHistory, func(message *chatv2.Message, _ int) string {
1818
if _, ok := message.Payload.MessageType.(*chatv2.MessagePayload_Assistant); ok {
1919
return fmt.Sprintf("Assistant: %s", message.Payload.GetAssistant().GetContent())
@@ -29,7 +29,13 @@ func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistor
2929
message := strings.Join(messages, "\n")
3030
message = fmt.Sprintf("%s\nBased on above conversation, generate a short, clear, and descriptive title that summarizes the main topic or purpose of the discussion. The title should be concise, specific, and use natural language. Avoid vague or generic titles. Use abbreviation and short words if possible. Use 3-5 words if possible. Give me the title only, no other text including any other words.", message)
3131

32-
_, resp, err := a.ChatCompletionV2(ctx, "gpt-5-nano", OpenAIChatHistory{
32+
// Default model if user is not using their own
33+
modelToUse := "gpt-5-nano"
34+
if llmProvider.IsCustomModel {
35+
modelToUse = modelSlug
36+
}
37+
38+
_, resp, err := a.ChatCompletionV2(ctx, modelToUse, OpenAIChatHistory{
3339
openai.SystemMessage("You are a helpful assistant that generates a title for a conversation."),
3440
openai.UserMessage(message),
3541
}, llmProvider)

internal/services/toolkit/client/utils_v2.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func appendAssistantTextResponseV2(openaiChatHistory *OpenAIChatHistory, inappCh
5353
})
5454
}
5555

56-
func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2) openaiv3.ChatCompletionNewParams {
56+
func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, isCustomModel bool) openaiv3.ChatCompletionNewParams {
5757
var reasoningModels = []string{
5858
"gpt-5",
5959
"gpt-5-mini",
@@ -66,6 +66,18 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2)
6666
"o1",
6767
"codex-mini-latest",
6868
}
69+
70+
// Other model providers generally do not support the Store param
71+
if isCustomModel {
72+
return openaiv3.ChatCompletionNewParams{
73+
Model: modelSlug,
74+
Temperature: openaiv3.Float(0.7),
75+
MaxCompletionTokens: openaiv3.Int(4000),
76+
Tools: toolRegistry.GetTools(),
77+
ParallelToolCalls: openaiv3.Bool(true),
78+
}
79+
}
80+
6981
for _, model := range reasoningModels {
7082
if strings.Contains(modelSlug, model) {
7183
return openaiv3.ChatCompletionNewParams{

pkg/gen/api/chat/v2/chat.pb.go

Lines changed: 14 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)