Skip to content

Commit 843fa7d

Browse files
feat: support openai-compatible embeddings base url (#69)
1 parent 3d3f819 commit 843fa7d

11 files changed

Lines changed: 219 additions & 11 deletions

File tree

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ WEBHOOK_MAX_COUNT=500
8989
# EMBEDDING_PROVIDER=openai
9090
# EMBEDDING_PROVIDER=google-gemini
9191
# EMBEDDING_PROVIDER_API_KEY=sk-... (required for openai and google; not used for google-gemini)
92+
# EMBEDDING_BASE_URL=https://embeddings.example.com/v1 (optional; only supported with EMBEDDING_PROVIDER=openai; must be an absolute http(s) URL without query, fragment, or user info)
9293
# EMBEDDING_GOOGLE_CLOUD_PROJECT= (required for google-gemini; or use GOOGLE_CLOUD_PROJECT)
9394
# EMBEDDING_GOOGLE_CLOUD_LOCATION= (required for google-gemini, e.g. europe-west3; or use GOOGLE_CLOUD_LOCATION)
9495
# GOOGLE_APPLICATION_CREDENTIALS= (optional; for google-gemini when outside Google Cloud: path to service account key JSON)

charts/hub/values.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ config:
127127
DATABASE_MAX_CONN_IDLE_TIME_SECONDS: "1800"
128128
DATABASE_HEALTH_CHECK_PERIOD_SECONDS: "60"
129129
DATABASE_CONNECT_TIMEOUT_SECONDS: "10"
130+
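# Optional: set together with EMBEDDING_PROVIDER=openai to target a self-hosted OpenAI-compatible embeddings endpoint (rejected for any other provider). Must be an absolute http(s) URL without query, fragment, or user info.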
131+
# Example: http://text-embeddings-inference.default.svc.cluster.local/v1
132+
EMBEDDING_BASE_URL: ""
130133

131134
secrets:
132135
create: true

cmd/api/app.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ func setupEmbeddingSearchHandler(
9090
Provider: embeddingProviderName,
9191
ProviderAPIKey: cfg.Embedding.ProviderAPIKey,
9292
Model: embeddingModel,
93+
BaseURL: cfg.Embedding.BaseURL,
9394
Normalize: cfg.Embedding.Normalize,
9495
GoogleCloudProject: cfg.Embedding.GoogleCloudProject,
9596
GoogleCloudLocation: cfg.Embedding.GoogleCloudLocation,

cmd/backfill-embeddings/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ func run() int {
8686
Provider: providerCanonical,
8787
ProviderAPIKey: cfg.Embedding.ProviderAPIKey,
8888
Model: embeddingModel,
89+
BaseURL: cfg.Embedding.BaseURL,
8990
Normalize: cfg.Embedding.Normalize,
9091
GoogleCloudProject: cfg.Embedding.GoogleCloudProject,
9192
GoogleCloudLocation: cfg.Embedding.GoogleCloudLocation,

cmd/worker/app.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ func NewWorkerApp(cfg *config.Config, db *pgxpool.Pool) (*WorkerApp, error) {
9393
Provider: providerName,
9494
ProviderAPIKey: cfg.Embedding.ProviderAPIKey,
9595
Model: embeddingModel,
96+
BaseURL: cfg.Embedding.BaseURL,
9697
Normalize: cfg.Embedding.Normalize,
9798
GoogleCloudProject: cfg.Embedding.GoogleCloudProject,
9899
GoogleCloudLocation: cfg.Embedding.GoogleCloudLocation,

internal/config/config.go

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ var (
2727
ErrWebhookMaxCount = errors.New("WEBHOOK_MAX_COUNT must be a positive integer")
2828
ErrDatabaseMinConnsExceedsMax = errors.New("DATABASE_MIN_CONNS must not exceed DATABASE_MAX_CONNS")
2929
ErrInvalidPublicBaseURL = errors.New("PUBLIC_BASE_URL must be an absolute http(s) URL without query or fragment")
30+
ErrInvalidEmbeddingBaseURL = errors.New("EMBEDDING_BASE_URL must be an absolute http(s) URL without query or fragment")
3031
)
3132

3233
// DefaultDatabaseURL is the default connection URL when DATABASE_URL is unset (local/test only).
@@ -115,6 +116,7 @@ type EmbeddingConfig struct {
115116
ProviderAPIKey string `env:"EMBEDDING_PROVIDER_API_KEY"`
116117
Provider string `env:"EMBEDDING_PROVIDER"`
117118
Model string `env:"EMBEDDING_MODEL"`
119+
BaseURL string `env:"EMBEDDING_BASE_URL"`
118120
MaxConcurrent int `env:"EMBEDDING_MAX_CONCURRENT" env-default:"5"`
119121
MaxAttempts int `env:"EMBEDDING_MAX_ATTEMPTS" env-default:"3"`
120122
Normalize bool `env:"EMBEDDING_NORMALIZE" env-default:"false"`
@@ -312,42 +314,51 @@ func validate(cfg *Config) error {
312314
}
313315

314316
if cfg.Server.PublicBaseURL != "" {
315-
normalized, err := normalizePublicBaseURL(cfg.Server.PublicBaseURL)
317+
normalized, err := normalizeHTTPBaseURL(cfg.Server.PublicBaseURL, ErrInvalidPublicBaseURL)
316318
if err != nil {
317319
return err
318320
}
319321

320322
cfg.Server.PublicBaseURL = normalized
321323
}
322324

325+
if cfg.Embedding.BaseURL != "" {
326+
normalized, err := normalizeHTTPBaseURL(cfg.Embedding.BaseURL, ErrInvalidEmbeddingBaseURL)
327+
if err != nil {
328+
return err
329+
}
330+
331+
cfg.Embedding.BaseURL = normalized
332+
}
333+
323334
return nil
324335
}
325336

326-
func normalizePublicBaseURL(raw string) (string, error) {
337+
func normalizeHTTPBaseURL(raw string, sentinel error) (string, error) {
327338
trimmed := strings.TrimSpace(raw)
328339
if trimmed == "" {
329-
return "", ErrInvalidPublicBaseURL
340+
return "", sentinel
330341
}
331342

332343
parsed, err := url.Parse(trimmed)
333344
if err != nil {
334-
return "", fmt.Errorf("%w: %w", ErrInvalidPublicBaseURL, err)
345+
return "", fmt.Errorf("%w: %w", sentinel, err)
335346
}
336347

337348
if !parsed.IsAbs() || parsed.Host == "" {
338-
return "", ErrInvalidPublicBaseURL
349+
return "", sentinel
339350
}
340351

341352
if parsed.Scheme != "http" && parsed.Scheme != "https" {
342-
return "", ErrInvalidPublicBaseURL
353+
return "", sentinel
343354
}
344355

345356
if parsed.RawQuery != "" || parsed.Fragment != "" {
346-
return "", ErrInvalidPublicBaseURL
357+
return "", sentinel
347358
}
348359

349360
if parsed.User != nil {
350-
return "", ErrInvalidPublicBaseURL
361+
return "", sentinel
351362
}
352363

353364
if parsed.Path == "/" {

internal/config/config_test.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,48 @@ func TestLoad_EmbeddingGoogleCloudLocation_precedence(t *testing.T) {
281281
}
282282
}
283283

284+
func TestLoad_EmbeddingBaseURL(t *testing.T) {
285+
t.Setenv("API_KEY", "test-api-key")
286+
t.Setenv("EMBEDDING_BASE_URL", "https://embeddings.example.com/v1/")
287+
288+
cfg, err := Load()
289+
if err != nil {
290+
t.Fatalf("Load() error = %v", err)
291+
}
292+
293+
if cfg.Embedding.BaseURL != "https://embeddings.example.com/v1" {
294+
t.Errorf("Embedding.BaseURL = %q, want https://embeddings.example.com/v1", cfg.Embedding.BaseURL)
295+
}
296+
}
297+
298+
func TestLoad_EmbeddingBaseURLValidation(t *testing.T) {
299+
tests := []struct {
300+
name string
301+
value string
302+
}{
303+
{name: "rejects relative url", value: "/v1"},
304+
{name: "rejects unsupported scheme", value: "ftp://embeddings.example.com/v1"},
305+
{name: "rejects query", value: "https://embeddings.example.com/v1?x=1"},
306+
{name: "rejects fragment", value: "https://embeddings.example.com/v1#frag"},
307+
{name: "rejects user info", value: "https://user:pass@embeddings.example.com/v1"},
308+
}
309+
310+
for _, tt := range tests {
311+
t.Run(tt.name, func(t *testing.T) {
312+
t.Setenv("EMBEDDING_BASE_URL", tt.value)
313+
314+
_, err := Load()
315+
if err == nil {
316+
t.Fatalf("Load() error = nil, want error")
317+
}
318+
319+
if !errors.Is(err, ErrInvalidEmbeddingBaseURL) {
320+
t.Fatalf("Load() error = %v, want %v", err, ErrInvalidEmbeddingBaseURL)
321+
}
322+
})
323+
}
324+
}
325+
284326
func TestLoad_PublicBaseURLValidation(t *testing.T) {
285327
tests := []struct {
286328
name string
@@ -582,7 +624,7 @@ func TestNormalizePublicBaseURL(t *testing.T) {
582624

583625
for _, tt := range tests {
584626
t.Run(tt.name, func(t *testing.T) {
585-
got, err := normalizePublicBaseURL(tt.value)
627+
got, err := normalizeHTTPBaseURL(tt.value, ErrInvalidPublicBaseURL)
586628
if err != nil {
587629
t.Fatalf("normalizePublicBaseURL() error = %v, want nil", err)
588630
}
@@ -597,7 +639,7 @@ func TestNormalizePublicBaseURL(t *testing.T) {
597639
func TestNormalizePublicBaseURLRejectsInvalidValues(t *testing.T) {
598640
for _, value := range []string{" ", "http://[::1", "hub.example.com"} {
599641
t.Run(value, func(t *testing.T) {
600-
_, err := normalizePublicBaseURL(value)
642+
_, err := normalizeHTTPBaseURL(value, ErrInvalidPublicBaseURL)
601643
if !errors.Is(err, ErrInvalidPublicBaseURL) {
602644
t.Fatalf("normalizePublicBaseURL() error = %v, want %v", err, ErrInvalidPublicBaseURL)
603645
}

internal/openai/client.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var (
2929
// Client calls the OpenAI embeddings API via the official SDK.
3030
type Client struct {
3131
sdk openaisdk.Client
32+
baseURL string
3233
dimensions int
3334
model string
3435
normalize bool
@@ -51,6 +52,13 @@ func WithModel(model string) ClientOption {
5152
}
5253
}
5354

55+
// WithBaseURL sets a custom OpenAI-compatible base URL (for example a self-hosted embeddings runtime); an empty value is ignored and the SDK's default base URL resolution applies.
56+
func WithBaseURL(baseURL string) ClientOption {
57+
return func(c *Client) {
58+
c.baseURL = baseURL
59+
}
60+
}
61+
5462
// WithNormalize enables L2 normalization of the embedding vector before returning (e.g. before storing or caching).
5563
func WithNormalize(normalize bool) ClientOption {
5664
return func(c *Client) {
@@ -62,14 +70,22 @@ func WithNormalize(normalize bool) ClientOption {
6270
// Embedding dimension is fixed (models.EmbeddingVectorDimensions); WithDimensions is optional for overrides.
6371
func NewClient(apiKey string, opts ...ClientOption) *Client {
6472
client := &Client{
65-
sdk: openaisdk.NewClient(option.WithAPIKey(apiKey)),
6673
dimensions: models.EmbeddingVectorDimensions,
6774
}
6875

6976
for _, opt := range opts {
7077
opt(client)
7178
}
7279

80+
sdkOpts := []option.RequestOption{
81+
option.WithAPIKey(apiKey),
82+
}
83+
if client.baseURL != "" {
84+
sdkOpts = append(sdkOpts, option.WithBaseURL(client.baseURL))
85+
}
86+
87+
client.sdk = openaisdk.NewClient(sdkOpts...)
88+
7389
return client
7490
}
7591

internal/openai/client_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"sync/atomic"
9+
"testing"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
type embeddingRequest struct {
16+
Input string `json:"input"`
17+
Model string `json:"model"`
18+
Dimensions int `json:"dimensions"`
19+
}
20+
21+
func newEmbeddingServer(t *testing.T, embedding []float64) (*httptest.Server, *atomic.Int32) {
22+
t.Helper()
23+
24+
var hitCount atomic.Int32
25+
26+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
27+
hitCount.Add(1)
28+
assert.Equal(t, "/v1/embeddings", r.URL.Path)
29+
assert.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
30+
31+
var req embeddingRequest
32+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
33+
t.Errorf("decode request body: %v", err)
34+
http.Error(w, "invalid request", http.StatusBadRequest)
35+
36+
return
37+
}
38+
39+
assert.Equal(t, "hello world", req.Input)
40+
assert.Equal(t, "test-model", req.Model)
41+
assert.Equal(t, 2, req.Dimensions)
42+
43+
w.Header().Set("Content-Type", "application/json")
44+
45+
if err := json.NewEncoder(w).Encode(map[string]any{
46+
"object": "list",
47+
"model": req.Model,
48+
"data": []map[string]any{
49+
{
50+
"object": "embedding",
51+
"index": 0,
52+
"embedding": embedding,
53+
},
54+
},
55+
"usage": map[string]any{
56+
"prompt_tokens": 1,
57+
"total_tokens": 1,
58+
},
59+
}); err != nil {
60+
t.Errorf("encode response body: %v", err)
61+
}
62+
}))
63+
64+
t.Cleanup(server.Close)
65+
66+
return server, &hitCount
67+
}
68+
69+
func TestCreateEmbedding_UsesExplicitBaseURLOverEnvironment(t *testing.T) {
70+
envServer, envHits := newEmbeddingServer(t, []float64{9, 9})
71+
explicitServer, explicitHits := newEmbeddingServer(t, []float64{1, 2})
72+
73+
t.Setenv("OPENAI_BASE_URL", envServer.URL+"/v1")
74+
75+
client := NewClient("sk-test",
76+
WithBaseURL(explicitServer.URL+"/v1"),
77+
WithDimensions(2),
78+
WithModel("test-model"),
79+
)
80+
81+
embedding, err := client.CreateEmbedding(context.Background(), "hello world")
82+
require.NoError(t, err)
83+
assert.Equal(t, []float32{1, 2}, embedding)
84+
assert.Equal(t, int32(0), envHits.Load())
85+
assert.Equal(t, int32(1), explicitHits.Load())
86+
}
87+
88+
func TestCreateEmbedding_UsesEnvironmentBaseURLWhenExplicitBaseURLIsUnset(t *testing.T) {
89+
envServer, envHits := newEmbeddingServer(t, []float64{3, 4})
90+
91+
t.Setenv("OPENAI_BASE_URL", envServer.URL+"/v1")
92+
93+
client := NewClient("sk-test",
94+
WithDimensions(2),
95+
WithModel("test-model"),
96+
)
97+
98+
embedding, err := client.CreateEmbedding(context.Background(), "hello world")
99+
require.NoError(t, err)
100+
assert.Equal(t, []float32{3, 4}, embedding)
101+
assert.Equal(t, int32(1), envHits.Load())
102+
}

internal/service/embedding_client_factory.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ var (
2424
ErrEmbeddingConfigInvalid = errors.New("embedding config invalid")
2525
// ErrEmbeddingProviderAPIKey is returned when an API-key-based provider is configured without a key.
2626
ErrEmbeddingProviderAPIKey = errors.New("EMBEDDING_PROVIDER_API_KEY is required for this provider")
27+
// ErrEmbeddingBaseURLUnsupported is returned when a custom base URL is configured for a non-openai provider.
28+
ErrEmbeddingBaseURLUnsupported = errors.New("EMBEDDING_BASE_URL is only supported for openai")
2729
// ErrEmbeddingGoogleGeminiConfig is returned when google-gemini is configured without project or location.
2830
ErrEmbeddingGoogleGeminiConfig = errors.New(
2931
"google-gemini requires EMBEDDING_GOOGLE_CLOUD_PROJECT and EMBEDDING_GOOGLE_CLOUD_LOCATION")
@@ -59,6 +61,7 @@ var embeddingProviderRegistry = map[string]providerEntry{
5961
func openAIEmbeddingFactory(_ context.Context, cfg EmbeddingClientConfig) (EmbeddingClient, error) {
6062
return openai.NewClient(cfg.ProviderAPIKey,
6163
openai.WithModel(cfg.Model),
64+
openai.WithBaseURL(cfg.BaseURL),
6265
openai.WithNormalize(cfg.Normalize),
6366
), nil
6467
}
@@ -102,6 +105,7 @@ type EmbeddingClientConfig struct {
102105
Provider string
103106
ProviderAPIKey string // API key for openai/google providers; not logged or serialized
104107
Model string
108+
BaseURL string
105109
Normalize bool
106110
GoogleCloudProject string
107111
GoogleCloudLocation string
@@ -121,6 +125,10 @@ func ValidateEmbeddingConfig(cfg EmbeddingClientConfig) error {
121125
return fmt.Errorf("%w: %s", ErrEmbeddingProviderAPIKey, provider)
122126
}
123127

128+
if cfg.BaseURL != "" && provider != EmbeddingProviderOpenAI {
129+
return fmt.Errorf("%w: %s", ErrEmbeddingBaseURLUnsupported, provider)
130+
}
131+
124132
if entry.RequiresGoogleGeminiConfig && (cfg.GoogleCloudProject == "" || cfg.GoogleCloudLocation == "") {
125133
return ErrEmbeddingGoogleGeminiConfig
126134
}

0 commit comments

Comments
 (0)