@@ -8,9 +8,11 @@ import com.browserbase.api.core.http.HttpMethod
88import com.browserbase.api.core.http.HttpRequest
99import com.browserbase.api.core.http.HttpRequestBody
1010import com.browserbase.api.core.http.HttpResponse
11+ import com.browserbase.api.core.http.ProxyAuthenticator
1112import com.browserbase.api.errors.StagehandIoException
1213import java.io.IOException
1314import java.io.InputStream
15+ import java.io.OutputStream
1416import java.net.Proxy
1517import java.time.Duration
1618import java.util.concurrent.CancellationException
@@ -20,10 +22,12 @@ import java.util.concurrent.TimeUnit
2022import javax.net.ssl.HostnameVerifier
2123import javax.net.ssl.SSLSocketFactory
2224import javax.net.ssl.X509TrustManager
25+ import kotlin.jvm.optionals.getOrNull
2326import okhttp3.Call
2427import okhttp3.Callback
2528import okhttp3.ConnectionPool
2629import okhttp3.Dispatcher
30+ import okhttp3.HttpUrl
2731import okhttp3.HttpUrl.Companion.toHttpUrl
2832import okhttp3.MediaType
2933import okhttp3.MediaType.Companion.toMediaType
@@ -33,6 +37,8 @@ import okhttp3.RequestBody.Companion.toRequestBody
3337import okhttp3.Response
3438import okhttp3.logging.HttpLoggingInterceptor
3539import okio.BufferedSink
40+ import okio.buffer
41+ import okio.sink
3642
3743class OkHttpClient
3844internal constructor (@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClient ) : HttpClient {
@@ -41,7 +47,7 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
4147 val call = newCall(request, requestOptions)
4248
4349 return try {
44- call.execute().toResponse ()
50+ call.execute().toHttpResponse ()
4551 } catch (e: IOException ) {
4652 throw StagehandIoException (" Request failed" , e)
4753 } finally {
@@ -59,7 +65,7 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
5965 call.enqueue(
6066 object : Callback {
6167 override fun onResponse (call : Call , response : Response ) {
62- future.complete(response.toResponse ())
68+ future.complete(response.toHttpResponse ())
6369 }
6470
6571 override fun onFailure (call : Call , e : IOException ) {
@@ -115,89 +121,6 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
115121 return client.newCall(request.toRequest(client))
116122 }
117123
118- private fun HttpRequest.toRequest (client : okhttp3.OkHttpClient ): Request {
119- var body: RequestBody ? = body?.toRequestBody()
120- if (body == null && requiresBody(method)) {
121- body = " " .toRequestBody()
122- }
123-
124- val builder = Request .Builder ().url(toUrl()).method(method.name, body)
125- headers.names().forEach { name ->
126- headers.values(name).forEach { builder.addHeader(name, it) }
127- }
128-
129- if (
130- ! headers.names().contains(" X-Stainless-Read-Timeout" ) && client.readTimeoutMillis != 0
131- ) {
132- builder.addHeader(
133- " X-Stainless-Read-Timeout" ,
134- Duration .ofMillis(client.readTimeoutMillis.toLong()).seconds.toString(),
135- )
136- }
137- if (! headers.names().contains(" X-Stainless-Timeout" ) && client.callTimeoutMillis != 0 ) {
138- builder.addHeader(
139- " X-Stainless-Timeout" ,
140- Duration .ofMillis(client.callTimeoutMillis.toLong()).seconds.toString(),
141- )
142- }
143-
144- return builder.build()
145- }
146-
147- /* * `OkHttpClient` always requires a request body for some methods. */
148- private fun requiresBody (method : HttpMethod ): Boolean =
149- when (method) {
150- HttpMethod .POST ,
151- HttpMethod .PUT ,
152- HttpMethod .PATCH -> true
153- else -> false
154- }
155-
156- private fun HttpRequest.toUrl (): String {
157- val builder = baseUrl.toHttpUrl().newBuilder()
158- pathSegments.forEach(builder::addPathSegment)
159- queryParams.keys().forEach { key ->
160- queryParams.values(key).forEach { builder.addQueryParameter(key, it) }
161- }
162-
163- return builder.toString()
164- }
165-
166- private fun HttpRequestBody.toRequestBody (): RequestBody {
167- val mediaType = contentType()?.toMediaType()
168- val length = contentLength()
169-
170- return object : RequestBody () {
171- override fun contentType (): MediaType ? = mediaType
172-
173- override fun contentLength (): Long = length
174-
175- override fun isOneShot (): Boolean = ! repeatable()
176-
177- override fun writeTo (sink : BufferedSink ) = writeTo(sink.outputStream())
178- }
179- }
180-
181- private fun Response.toResponse (): HttpResponse {
182- val headers = headers.toHeaders()
183-
184- return object : HttpResponse {
185- override fun statusCode (): Int = code
186-
187- override fun headers (): Headers = headers
188-
189- override fun body (): InputStream = body!! .byteStream()
190-
191- override fun close () = body!! .close()
192- }
193- }
194-
195- private fun okhttp3.Headers.toHeaders (): Headers {
196- val headersBuilder = Headers .builder()
197- forEach { (name, value) -> headersBuilder.put(name, value) }
198- return headersBuilder.build()
199- }
200-
201124 companion object {
202125 @JvmStatic fun builder () = Builder ()
203126 }
@@ -206,6 +129,7 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
206129
207130 private var timeout: Timeout = Timeout .default()
208131 private var proxy: Proxy ? = null
132+ private var proxyAuthenticator: ProxyAuthenticator ? = null
209133 private var maxIdleConnections: Int? = null
210134 private var keepAliveDuration: Duration ? = null
211135 private var dispatcherExecutorService: ExecutorService ? = null
@@ -219,6 +143,10 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
219143
220144 fun proxy (proxy : Proxy ? ) = apply { this .proxy = proxy }
221145
146+ fun proxyAuthenticator (proxyAuthenticator : ProxyAuthenticator ? ) = apply {
147+ this .proxyAuthenticator = proxyAuthenticator
148+ }
149+
222150 /* *
223151 * Sets the maximum number of idle connections kept by the underlying [ConnectionPool].
224152 *
@@ -268,6 +196,19 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
268196 .callTimeout(timeout.request())
269197 .proxy(proxy)
270198 .apply {
199+ proxyAuthenticator?.let { auth ->
200+ proxyAuthenticator { route, response ->
201+ auth
202+ .authenticate(
203+ route?.proxy ? : Proxy .NO_PROXY ,
204+ response.request.toHttpRequest(),
205+ response.toHttpResponse(),
206+ )
207+ .getOrNull()
208+ ?.toRequest(client = null )
209+ }
210+ }
211+
271212 dispatcherExecutorService?.let { dispatcher(Dispatcher (it)) }
272213
273214 val maxIdleConnections = maxIdleConnections
@@ -307,3 +248,126 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
307248 )
308249 }
309250}
251+
252+ private fun HttpRequest.toRequest (client : okhttp3.OkHttpClient ? ): Request {
253+ var body: RequestBody ? = body?.toRequestBody()
254+ if (body == null && requiresBody(method)) {
255+ body = " " .toRequestBody()
256+ }
257+
258+ val builder = Request .Builder ().url(toUrl()).method(method.name, body)
259+ headers.names().forEach { name -> headers.values(name).forEach { builder.addHeader(name, it) } }
260+
261+ if (client != null ) {
262+ if (
263+ ! headers.names().contains(" X-Stainless-Read-Timeout" ) && client.readTimeoutMillis != 0
264+ ) {
265+ builder.addHeader(
266+ " X-Stainless-Read-Timeout" ,
267+ Duration .ofMillis(client.readTimeoutMillis.toLong()).seconds.toString(),
268+ )
269+ }
270+ if (! headers.names().contains(" X-Stainless-Timeout" ) && client.callTimeoutMillis != 0 ) {
271+ builder.addHeader(
272+ " X-Stainless-Timeout" ,
273+ Duration .ofMillis(client.callTimeoutMillis.toLong()).seconds.toString(),
274+ )
275+ }
276+ }
277+
278+ return builder.build()
279+ }
280+
281+ /* * `OkHttpClient` always requires a request body for some methods. */
282+ private fun requiresBody (method : HttpMethod ): Boolean =
283+ when (method) {
284+ HttpMethod .POST ,
285+ HttpMethod .PUT ,
286+ HttpMethod .PATCH -> true
287+ else -> false
288+ }
289+
290+ private fun HttpRequest.toUrl (): String {
291+ val builder = baseUrl.toHttpUrl().newBuilder()
292+ pathSegments.forEach(builder::addPathSegment)
293+ queryParams.keys().forEach { key ->
294+ queryParams.values(key).forEach { builder.addQueryParameter(key, it) }
295+ }
296+
297+ return builder.toString()
298+ }
299+
300+ private fun HttpRequestBody.toRequestBody (): RequestBody {
301+ val mediaType = contentType()?.toMediaType()
302+ val length = contentLength()
303+
304+ return object : RequestBody () {
305+ override fun contentType (): MediaType ? = mediaType
306+
307+ override fun contentLength (): Long = length
308+
309+ override fun isOneShot (): Boolean = ! repeatable()
310+
311+ override fun writeTo (sink : BufferedSink ) = writeTo(sink.outputStream())
312+ }
313+ }
314+
315+ private fun Request.toHttpRequest (): HttpRequest {
316+ val builder = HttpRequest .builder().method(HttpMethod .valueOf(method)).baseUrl(url.toBaseUrl())
317+ url.pathSegments.forEach(builder::addPathSegment)
318+ url.queryParameterNames.forEach { name ->
319+ url.queryParameterValues(name).filterNotNull().forEach { builder.putQueryParam(name, it) }
320+ }
321+ headers.forEach { (name, value) -> builder.putHeader(name, value) }
322+ body?.let { builder.body(it.toHttpRequestBody()) }
323+ return builder.build()
324+ }
325+
326+ private fun HttpUrl.toBaseUrl (): String = buildString {
327+ append(scheme).append(" ://" ).append(host)
328+ if (port != HttpUrl .defaultPort(scheme)) {
329+ append(" :" ).append(port)
330+ }
331+ }
332+
333+ private fun RequestBody.toHttpRequestBody (): HttpRequestBody {
334+ val mediaType = contentType()?.toString()
335+ val length = contentLength()
336+ val isOneShot = isOneShot()
337+ val source = this
338+ return object : HttpRequestBody {
339+ override fun contentType (): String? = mediaType
340+
341+ override fun contentLength (): Long = length
342+
343+ override fun repeatable (): Boolean = ! isOneShot
344+
345+ override fun writeTo (outputStream : OutputStream ) {
346+ val sink = outputStream.sink().buffer()
347+ source.writeTo(sink)
348+ sink.flush()
349+ }
350+
351+ override fun close () {}
352+ }
353+ }
354+
355+ private fun Response.toHttpResponse (): HttpResponse {
356+ val headers = headers.toHeaders()
357+
358+ return object : HttpResponse {
359+ override fun statusCode (): Int = code
360+
361+ override fun headers (): Headers = headers
362+
363+ override fun body (): InputStream = body!! .byteStream()
364+
365+ override fun close () = body!! .close()
366+ }
367+ }
368+
369+ private fun okhttp3.Headers.toHeaders (): Headers {
370+ val headersBuilder = Headers .builder()
371+ forEach { (name, value) -> headersBuilder.put(name, value) }
372+ return headersBuilder.build()
373+ }
0 commit comments