-
Notifications
You must be signed in to change notification settings - Fork 67.1k
Expand file tree
/
Copy pathai-search-proxy.ts
More file actions
169 lines (144 loc) · 5.23 KB
/
ai-search-proxy.ts
File metadata and controls
169 lines (144 loc) · 5.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import { Response } from 'express'
import statsd from '@/observability/lib/statsd'
import { fetchStream } from '@/frame/lib/fetch-utils'
import { getHmacWithEpoch } from '@/search/lib/helpers/get-cse-copilot-auth'
import { getCSECopilotSource } from '@/search/lib/helpers/cse-copilot-docs-versions'
import type { ExtendedRequest } from '@/types'
import { handleExternalSearchAnalytics } from '@/search/lib/helpers/external-search-analytics'
/**
 * Upper bound, in milliseconds, on how long we wait for the upstream AI
 * search service to start responding. Connecting and receiving the first
 * byte must both happen inside this window; once the connection is
 * established, streaming the rest of the answer may take longer.
 */
const AI_SEARCH_TIMEOUT_MS = 9 * 1_000
/**
 * Express handler that proxies an AI search query to the cse-copilot
 * `/answers` endpoint and streams the upstream response body back to the
 * client as `application/x-ndjson`.
 *
 * Flow:
 *  1. Validate `query` and `version` from the request body (400 on failure).
 *  2. Run external search analytics / client_name validation.
 *  3. POST the query upstream and relay the body chunk by chunk.
 *
 * Once the response headers have been flushed a failure can no longer be
 * reported with an HTTP status code, so it is appended to the stream as a
 * newline-terminated JSON error object instead.
 *
 * @param req - Incoming request; reads `req.body.query`, `req.body.version` and `req.language`.
 * @param res - Express response that the upstream stream is written to.
 */
export const aiSearchProxy = async (req: ExtendedRequest, res: Response) => {
  const { query, version } = req.body ?? {}
  const errors: { message: string }[] = []

  // Validate request body
  if (!query) {
    errors.push({ message: `Missing required key 'query' in request body` })
  } else if (typeof query !== 'string') {
    errors.push({ message: `Invalid 'query' in request body. Must be a string` })
  }

  let docsSource = ''
  try {
    docsSource = getCSECopilotSource(version)
  } catch (error: unknown) {
    const message = error instanceof Error ? error.message : 'Invalid version'
    errors.push({ message })
  }

  if (errors.length) {
    res.status(400).json({ errors })
    return
  }

  // Handle search analytics and client_name validation
  const analyticsError = await handleExternalSearchAnalytics(req, 'ai-search')
  if (analyticsError) {
    res.status(analyticsError.status).json({
      errors: [{ message: analyticsError.error }],
    })
    return
  }

  const diagnosticTags = [
    `version:${version}`.slice(0, 200),
    `language:${req.language}`.slice(0, 200),
    `queryLength:${query.length}`.slice(0, 200),
  ]
  statsd.increment('ai-search.call', 1, diagnosticTags)

  // Reports a failure to the client. Before headers are sent this is a normal
  // JSON error response with `status`; afterwards the status line is already
  // on the wire, so the error is appended to the stream as one JSON line and
  // the response is closed. Calling `res.status().json()` after the headers
  // were flushed would throw instead of informing the client.
  const sendFailure = (status: number, message: string) => {
    const payload = { errors: [{ message }] }
    if (!res.headersSent) {
      res.status(status).json(payload)
      return
    }
    // Nothing more can be written once the response has been ended.
    if (!res.writableEnded) {
      res.write(`${JSON.stringify(payload)}\n`)
      res.end()
    }
  }

  const startTime = Date.now()
  let totalChars = 0

  const body = {
    chat_context: 'docs',
    docs_source: docsSource,
    query,
    stream: true,
  }

  try {
    const response = await fetchStream(
      `${process.env.CSE_COPILOT_ENDPOINT}/answers`,
      {
        method: 'POST',
        body: JSON.stringify(body),
        headers: {
          Authorization: getHmacWithEpoch(),
          'Content-Type': 'application/json',
        },
      },
      {
        timeout: AI_SEARCH_TIMEOUT_MS,
        throwHttpErrors: false,
      },
    )

    if (!response.ok) {
      const errorMessage = `Upstream server responded with status code ${response.status}`
      console.error(errorMessage)
      statsd.increment('ai-search.stream_response_error', 1, diagnosticTags)
      res.status(response.status).json({
        errors: [{ message: errorMessage }],
        upstreamStatus: response.status,
      })
      return
    }

    // Check for a body BEFORE flushing headers: after the flush the 500
    // status below could no longer be set.
    if (!response.body) {
      sendFailure(500, 'No response body')
      return
    }

    // Set response headers
    res.setHeader('Content-Type', 'application/x-ndjson')
    res.flushHeaders()

    // Stream the response body
    const reader = response.body.getReader()
    const decoder = new TextDecoder()

    try {
      while (true) {
        const { done, value } = await reader.read()
        if (done) {
          break
        }

        // Decode chunk and count characters
        const chunk = decoder.decode(value, { stream: true })
        totalChars += chunk.length

        // Write chunk to response
        res.write(chunk)
      }

      // Calculate metrics on stream end
      const totalResponseTime = Date.now() - startTime // in ms
      const charPerMsRatio = totalResponseTime > 0 ? totalChars / totalResponseTime : 0 // chars per ms

      statsd.gauge('ai-search.total_response_time', totalResponseTime, diagnosticTags)
      statsd.gauge('ai-search.response_chars_per_ms', charPerMsRatio, diagnosticTags)
      statsd.increment('ai-search.success_stream_end', 1, diagnosticTags)

      res.end()
    } catch (streamError: unknown) {
      console.error('Error streaming from cse-copilot:', streamError)
      statsd.increment('ai-search.stream_error', 1, diagnosticTags)
      sendFailure(500, 'Internal server error')
    } finally {
      // The reader is acquired immediately before this try, so this is the
      // single place where its lock has to be released.
      reader.releaseLock()
    }
  } catch (error: unknown) {
    const isTimeout = error instanceof Error && error.message.includes('timed out')
    if (isTimeout) {
      statsd.increment('ai-search.timeout', 1, diagnosticTags)
      console.error(`AI search request timed out after ${AI_SEARCH_TIMEOUT_MS}ms`)
      sendFailure(504, 'Upstream request timed out')
    } else {
      statsd.increment('ai-search.route_error', 1, diagnosticTags)
      console.error('Error posting /answers to cse-copilot:', error)
      sendFailure(500, 'Internal server error')
    }
  }
}