Skip to content

Commit b217602

Browse files
Add embedding gateway override for semantic search (#697)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
1 parent 922d48f commit b217602

5 files changed

Lines changed: 38 additions & 8 deletions

File tree

.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ CLOUDFLARE_API_TOKEN=MOCK_CLOUDFLARE_API_TOKEN
105105
CLOUDFLARE_VECTORIZE_INDEX=MOCK_CLOUDFLARE_VECTORIZE_INDEX
106106
# Route Workers AI requests through Cloudflare AI Gateway (gateway name/id)
107107
CLOUDFLARE_AI_GATEWAY_ID=MOCK_CLOUDFLARE_AI_GATEWAY_ID
108+
# Optional: route embeddings through a different AI Gateway (for example,
109+
# without guardrails). Falls back to CLOUDFLARE_AI_GATEWAY_ID when omitted.
110+
CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID=MOCK_CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID
108111
# AI Gateway Authenticated Gateway token (sent as `cf-aig-authorization`)
109112
CLOUDFLARE_AI_GATEWAY_AUTH_TOKEN=MOCK_CLOUDFLARE_AI_GATEWAY_AUTH_TOKEN
110113
# Text model used for general generation (defaults to @cf/meta/llama-3.1-8b-instruct)

app/utils/env.server.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ const schemaBase = z.object({
6767
CLOUDFLARE_API_TOKEN: nonEmptyString,
6868
/** AI Gateway "id" is the gateway name you create in Cloudflare. */
6969
CLOUDFLARE_AI_GATEWAY_ID: nonEmptyString,
70+
/**
71+
* Optional embedding-specific AI Gateway id.
72+
* Falls back to `CLOUDFLARE_AI_GATEWAY_ID` when omitted.
73+
*/
74+
CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID: z.string().trim().optional(),
7075
/**
7176
* AI Gateway authenticated gateway token (used as `cf-aig-authorization`).
7277
*/
@@ -154,7 +159,11 @@ type BaseEnvInput = z.input<typeof schemaBase>
154159

155160
export type Env = Omit<
156161
BaseEnv,
157-
'PORT' | 'MOCKS' | 'DATABASE_PATH' | 'FLY_MACHINE_ID'
162+
| 'PORT'
163+
| 'MOCKS'
164+
| 'DATABASE_PATH'
165+
| 'FLY_MACHINE_ID'
166+
| 'CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID'
158167
> & {
159168
PORT: number
160169
MOCKS: boolean
@@ -173,6 +182,11 @@ export type Env = Omit<
173182
* Used to format generated Call Kent transcripts into readable paragraphs.
174183
*/
175184
CLOUDFLARE_AI_CALL_KENT_TRANSCRIPT_FORMAT_MODEL: string
185+
/**
186+
* Embeddings can be routed through a separate gateway (for example, with
187+
* guardrails disabled) without affecting other AI routes.
188+
*/
189+
CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID: string
176190
/** Derived from CLOUDFLARE_ACCOUNT_ID when not explicitly set. */
177191
R2_ENDPOINT: string
178192
}
@@ -241,6 +255,9 @@ export function getEnv(): Env {
241255
allowedActionOrigins: computeAllowedActionOrigins(values),
242256
FLY_MACHINE_ID: values.FLY_MACHINE_ID ?? 'unknown',
243257
R2_ENDPOINT: derivedR2Endpoint,
258+
CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID:
259+
values.CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID ||
260+
values.CLOUDFLARE_AI_GATEWAY_ID,
244261
CLOUDFLARE_AI_CALL_KENT_METADATA_MODEL:
245262
values.CLOUDFLARE_AI_CALL_KENT_METADATA_MODEL ??
246263
values.CLOUDFLARE_AI_TEXT_MODEL,

app/utils/semantic-search.server.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ function getRequiredSemanticSearchEnv() {
178178
return {
179179
accountId: env.CLOUDFLARE_ACCOUNT_ID,
180180
apiToken: env.CLOUDFLARE_API_TOKEN,
181-
gatewayId: env.CLOUDFLARE_AI_GATEWAY_ID,
181+
embeddingGatewayId: env.CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID,
182182
gatewayAuthToken: env.CLOUDFLARE_AI_GATEWAY_AUTH_TOKEN,
183183
indexName: env.CLOUDFLARE_VECTORIZE_INDEX,
184184
embeddingModel: env.CLOUDFLARE_AI_EMBEDDING_MODEL,
@@ -221,19 +221,23 @@ async function cloudflareFetch(
221221
async function getEmbedding({
222222
accountId,
223223
apiToken,
224-
gatewayId,
224+
embeddingGatewayId,
225225
gatewayAuthToken,
226226
model,
227227
text,
228228
}: {
229229
accountId: string
230230
apiToken: string
231-
gatewayId: string
231+
embeddingGatewayId: string
232232
gatewayAuthToken: string
233233
model: string
234234
text: string
235235
}) {
236-
const url = getWorkersAiRunUrl({ accountId, gatewayId, model })
236+
const url = getWorkersAiRunUrl({
237+
accountId,
238+
gatewayId: embeddingGatewayId,
239+
model,
240+
})
237241
const res = await fetch(url, {
238242
method: 'POST',
239243
headers: {
@@ -492,7 +496,7 @@ export async function semanticSearchKCD({
492496
const {
493497
accountId,
494498
apiToken,
495-
gatewayId,
499+
embeddingGatewayId,
496500
gatewayAuthToken,
497501
indexName,
498502
embeddingModel,
@@ -529,7 +533,7 @@ export async function semanticSearchKCD({
529533
const vector = await getEmbedding({
530534
accountId,
531535
apiToken,
532-
gatewayId,
536+
embeddingGatewayId,
533537
gatewayAuthToken,
534538
model: embeddingModel,
535539
text: cleanedQuery,

other/semantic-search/cloudflare.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@ function getRequiredEnv(name: string) {
1414
}
1515

1616
export function getCloudflareConfig() {
17+
const defaultGatewayId = getRequiredEnv('CLOUDFLARE_AI_GATEWAY_ID')
18+
const embeddingGatewayId =
19+
process.env.CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID?.trim() || defaultGatewayId
20+
1721
return {
1822
accountId: getRequiredEnv('CLOUDFLARE_ACCOUNT_ID'),
1923
apiToken: getRequiredEnv('CLOUDFLARE_API_TOKEN'),
20-
gatewayId: getRequiredEnv('CLOUDFLARE_AI_GATEWAY_ID'),
24+
// Embedding jobs can use a dedicated gateway without guardrails.
25+
gatewayId: embeddingGatewayId,
2126
gatewayAuthToken: getRequiredEnv('CLOUDFLARE_AI_GATEWAY_AUTH_TOKEN'),
2227
vectorizeIndex: getRequiredEnv('CLOUDFLARE_VECTORIZE_INDEX'),
2328
embeddingModel:

other/semantic-search/readme.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ and shared utilities.
2424
- `CLOUDFLARE_API_TOKEN`
2525
- `CLOUDFLARE_VECTORIZE_INDEX`
2626
- `CLOUDFLARE_AI_EMBEDDING_MODEL` (optional; defaults in code)
27+
- `CLOUDFLARE_AI_EMBEDDING_GATEWAY_ID` (optional; defaults to `CLOUDFLARE_AI_GATEWAY_ID`)
2728

2829
- `R2_BUCKET`
2930

0 commit comments

Comments
 (0)