Skip to content

Commit 7572751

Browse files
feat(client): support proxy authentication
1 parent 5f27b6b commit 7572751

5 files changed

Lines changed: 257 additions & 85 deletions

File tree

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,21 @@ StagehandClient client = StagehandOkHttpClient.builder()
636636
.build();
637637
```
638638

639+
If the proxy responds with `407 Proxy Authentication Required`, supply credentials by also configuring `proxyAuthenticator`:
640+
641+
```java
642+
import com.browserbase.api.client.StagehandClient;
643+
import com.browserbase.api.client.okhttp.StagehandOkHttpClient;
644+
import com.browserbase.api.core.http.ProxyAuthenticator;
645+
646+
StagehandClient client = StagehandOkHttpClient.builder()
647+
.fromEnv()
648+
.proxy(...)
649+
// Or a custom implementation of `ProxyAuthenticator`.
650+
.proxyAuthenticator(ProxyAuthenticator.basic("username", "password"))
651+
.build();
652+
```
653+
639654
### Connection pooling
640655

641656
To customize the underlying OkHttp connection pool, configure the client using the `maxIdleConnections` and `keepAliveDuration` methods:

stagehand-java-client-okhttp/src/main/kotlin/com/browserbase/api/client/okhttp/OkHttpClient.kt

Lines changed: 149 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ import com.browserbase.api.core.http.HttpMethod
88
import com.browserbase.api.core.http.HttpRequest
99
import com.browserbase.api.core.http.HttpRequestBody
1010
import com.browserbase.api.core.http.HttpResponse
11+
import com.browserbase.api.core.http.ProxyAuthenticator
1112
import com.browserbase.api.errors.StagehandIoException
1213
import java.io.IOException
1314
import java.io.InputStream
15+
import java.io.OutputStream
1416
import java.net.Proxy
1517
import java.time.Duration
1618
import java.util.concurrent.CancellationException
@@ -20,10 +22,12 @@ import java.util.concurrent.TimeUnit
2022
import javax.net.ssl.HostnameVerifier
2123
import javax.net.ssl.SSLSocketFactory
2224
import javax.net.ssl.X509TrustManager
25+
import kotlin.jvm.optionals.getOrNull
2326
import okhttp3.Call
2427
import okhttp3.Callback
2528
import okhttp3.ConnectionPool
2629
import okhttp3.Dispatcher
30+
import okhttp3.HttpUrl
2731
import okhttp3.HttpUrl.Companion.toHttpUrl
2832
import okhttp3.MediaType
2933
import okhttp3.MediaType.Companion.toMediaType
@@ -33,6 +37,8 @@ import okhttp3.RequestBody.Companion.toRequestBody
3337
import okhttp3.Response
3438
import okhttp3.logging.HttpLoggingInterceptor
3539
import okio.BufferedSink
40+
import okio.buffer
41+
import okio.sink
3642

3743
class OkHttpClient
3844
internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClient) : HttpClient {
@@ -41,7 +47,7 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
4147
val call = newCall(request, requestOptions)
4248

4349
return try {
44-
call.execute().toResponse()
50+
call.execute().toHttpResponse()
4551
} catch (e: IOException) {
4652
throw StagehandIoException("Request failed", e)
4753
} finally {
@@ -59,7 +65,7 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
5965
call.enqueue(
6066
object : Callback {
6167
override fun onResponse(call: Call, response: Response) {
62-
future.complete(response.toResponse())
68+
future.complete(response.toHttpResponse())
6369
}
6470

6571
override fun onFailure(call: Call, e: IOException) {
@@ -115,89 +121,6 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
115121
return client.newCall(request.toRequest(client))
116122
}
117123

118-
private fun HttpRequest.toRequest(client: okhttp3.OkHttpClient): Request {
119-
var body: RequestBody? = body?.toRequestBody()
120-
if (body == null && requiresBody(method)) {
121-
body = "".toRequestBody()
122-
}
123-
124-
val builder = Request.Builder().url(toUrl()).method(method.name, body)
125-
headers.names().forEach { name ->
126-
headers.values(name).forEach { builder.addHeader(name, it) }
127-
}
128-
129-
if (
130-
!headers.names().contains("X-Stainless-Read-Timeout") && client.readTimeoutMillis != 0
131-
) {
132-
builder.addHeader(
133-
"X-Stainless-Read-Timeout",
134-
Duration.ofMillis(client.readTimeoutMillis.toLong()).seconds.toString(),
135-
)
136-
}
137-
if (!headers.names().contains("X-Stainless-Timeout") && client.callTimeoutMillis != 0) {
138-
builder.addHeader(
139-
"X-Stainless-Timeout",
140-
Duration.ofMillis(client.callTimeoutMillis.toLong()).seconds.toString(),
141-
)
142-
}
143-
144-
return builder.build()
145-
}
146-
147-
/** `OkHttpClient` always requires a request body for some methods. */
148-
private fun requiresBody(method: HttpMethod): Boolean =
149-
when (method) {
150-
HttpMethod.POST,
151-
HttpMethod.PUT,
152-
HttpMethod.PATCH -> true
153-
else -> false
154-
}
155-
156-
private fun HttpRequest.toUrl(): String {
157-
val builder = baseUrl.toHttpUrl().newBuilder()
158-
pathSegments.forEach(builder::addPathSegment)
159-
queryParams.keys().forEach { key ->
160-
queryParams.values(key).forEach { builder.addQueryParameter(key, it) }
161-
}
162-
163-
return builder.toString()
164-
}
165-
166-
private fun HttpRequestBody.toRequestBody(): RequestBody {
167-
val mediaType = contentType()?.toMediaType()
168-
val length = contentLength()
169-
170-
return object : RequestBody() {
171-
override fun contentType(): MediaType? = mediaType
172-
173-
override fun contentLength(): Long = length
174-
175-
override fun isOneShot(): Boolean = !repeatable()
176-
177-
override fun writeTo(sink: BufferedSink) = writeTo(sink.outputStream())
178-
}
179-
}
180-
181-
private fun Response.toResponse(): HttpResponse {
182-
val headers = headers.toHeaders()
183-
184-
return object : HttpResponse {
185-
override fun statusCode(): Int = code
186-
187-
override fun headers(): Headers = headers
188-
189-
override fun body(): InputStream = body!!.byteStream()
190-
191-
override fun close() = body!!.close()
192-
}
193-
}
194-
195-
private fun okhttp3.Headers.toHeaders(): Headers {
196-
val headersBuilder = Headers.builder()
197-
forEach { (name, value) -> headersBuilder.put(name, value) }
198-
return headersBuilder.build()
199-
}
200-
201124
companion object {
202125
@JvmStatic fun builder() = Builder()
203126
}
@@ -206,6 +129,7 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
206129

207130
private var timeout: Timeout = Timeout.default()
208131
private var proxy: Proxy? = null
132+
private var proxyAuthenticator: ProxyAuthenticator? = null
209133
private var maxIdleConnections: Int? = null
210134
private var keepAliveDuration: Duration? = null
211135
private var dispatcherExecutorService: ExecutorService? = null
@@ -219,6 +143,10 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
219143

220144
fun proxy(proxy: Proxy?) = apply { this.proxy = proxy }
221145

146+
fun proxyAuthenticator(proxyAuthenticator: ProxyAuthenticator?) = apply {
147+
this.proxyAuthenticator = proxyAuthenticator
148+
}
149+
222150
/**
223151
* Sets the maximum number of idle connections kept by the underlying [ConnectionPool].
224152
*
@@ -268,6 +196,19 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
268196
.callTimeout(timeout.request())
269197
.proxy(proxy)
270198
.apply {
199+
proxyAuthenticator?.let { auth ->
200+
proxyAuthenticator { route, response ->
201+
auth
202+
.authenticate(
203+
route?.proxy ?: Proxy.NO_PROXY,
204+
response.request.toHttpRequest(),
205+
response.toHttpResponse(),
206+
)
207+
.getOrNull()
208+
?.toRequest(client = null)
209+
}
210+
}
211+
271212
dispatcherExecutorService?.let { dispatcher(Dispatcher(it)) }
272213

273214
val maxIdleConnections = maxIdleConnections
@@ -307,3 +248,126 @@ internal constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClie
307248
)
308249
}
309250
}
251+
252+
private fun HttpRequest.toRequest(client: okhttp3.OkHttpClient?): Request {
253+
var body: RequestBody? = body?.toRequestBody()
254+
if (body == null && requiresBody(method)) {
255+
body = "".toRequestBody()
256+
}
257+
258+
val builder = Request.Builder().url(toUrl()).method(method.name, body)
259+
headers.names().forEach { name -> headers.values(name).forEach { builder.addHeader(name, it) } }
260+
261+
if (client != null) {
262+
if (
263+
!headers.names().contains("X-Stainless-Read-Timeout") && client.readTimeoutMillis != 0
264+
) {
265+
builder.addHeader(
266+
"X-Stainless-Read-Timeout",
267+
Duration.ofMillis(client.readTimeoutMillis.toLong()).seconds.toString(),
268+
)
269+
}
270+
if (!headers.names().contains("X-Stainless-Timeout") && client.callTimeoutMillis != 0) {
271+
builder.addHeader(
272+
"X-Stainless-Timeout",
273+
Duration.ofMillis(client.callTimeoutMillis.toLong()).seconds.toString(),
274+
)
275+
}
276+
}
277+
278+
return builder.build()
279+
}
280+
281+
/** `OkHttpClient` always requires a request body for some methods. */
282+
private fun requiresBody(method: HttpMethod): Boolean =
283+
when (method) {
284+
HttpMethod.POST,
285+
HttpMethod.PUT,
286+
HttpMethod.PATCH -> true
287+
else -> false
288+
}
289+
290+
private fun HttpRequest.toUrl(): String {
291+
val builder = baseUrl.toHttpUrl().newBuilder()
292+
pathSegments.forEach(builder::addPathSegment)
293+
queryParams.keys().forEach { key ->
294+
queryParams.values(key).forEach { builder.addQueryParameter(key, it) }
295+
}
296+
297+
return builder.toString()
298+
}
299+
300+
private fun HttpRequestBody.toRequestBody(): RequestBody {
301+
val mediaType = contentType()?.toMediaType()
302+
val length = contentLength()
303+
304+
return object : RequestBody() {
305+
override fun contentType(): MediaType? = mediaType
306+
307+
override fun contentLength(): Long = length
308+
309+
override fun isOneShot(): Boolean = !repeatable()
310+
311+
override fun writeTo(sink: BufferedSink) = writeTo(sink.outputStream())
312+
}
313+
}
314+
315+
private fun Request.toHttpRequest(): HttpRequest {
316+
val builder = HttpRequest.builder().method(HttpMethod.valueOf(method)).baseUrl(url.toBaseUrl())
317+
url.pathSegments.forEach(builder::addPathSegment)
318+
url.queryParameterNames.forEach { name ->
319+
url.queryParameterValues(name).filterNotNull().forEach { builder.putQueryParam(name, it) }
320+
}
321+
headers.forEach { (name, value) -> builder.putHeader(name, value) }
322+
body?.let { builder.body(it.toHttpRequestBody()) }
323+
return builder.build()
324+
}
325+
326+
private fun HttpUrl.toBaseUrl(): String = buildString {
327+
append(scheme).append("://").append(host)
328+
if (port != HttpUrl.defaultPort(scheme)) {
329+
append(":").append(port)
330+
}
331+
}
332+
333+
private fun RequestBody.toHttpRequestBody(): HttpRequestBody {
334+
val mediaType = contentType()?.toString()
335+
val length = contentLength()
336+
val isOneShot = isOneShot()
337+
val source = this
338+
return object : HttpRequestBody {
339+
override fun contentType(): String? = mediaType
340+
341+
override fun contentLength(): Long = length
342+
343+
override fun repeatable(): Boolean = !isOneShot
344+
345+
override fun writeTo(outputStream: OutputStream) {
346+
val sink = outputStream.sink().buffer()
347+
source.writeTo(sink)
348+
sink.flush()
349+
}
350+
351+
override fun close() {}
352+
}
353+
}
354+
355+
private fun Response.toHttpResponse(): HttpResponse {
356+
val headers = headers.toHeaders()
357+
358+
return object : HttpResponse {
359+
override fun statusCode(): Int = code
360+
361+
override fun headers(): Headers = headers
362+
363+
override fun body(): InputStream = body!!.byteStream()
364+
365+
override fun close() = body!!.close()
366+
}
367+
}
368+
369+
private fun okhttp3.Headers.toHeaders(): Headers {
370+
val headersBuilder = Headers.builder()
371+
forEach { (name, value) -> headersBuilder.put(name, value) }
372+
return headersBuilder.build()
373+
}

stagehand-java-client-okhttp/src/main/kotlin/com/browserbase/api/client/okhttp/StagehandOkHttpClient.kt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import com.browserbase.api.core.Timeout
1010
import com.browserbase.api.core.http.AsyncStreamResponse
1111
import com.browserbase.api.core.http.Headers
1212
import com.browserbase.api.core.http.HttpClient
13+
import com.browserbase.api.core.http.ProxyAuthenticator
1314
import com.browserbase.api.core.http.QueryParams
1415
import com.browserbase.api.core.jsonMapper
1516
import com.fasterxml.jackson.databind.json.JsonMapper
@@ -49,6 +50,7 @@ class StagehandOkHttpClient private constructor() {
4950
private var clientOptions: ClientOptions.Builder = ClientOptions.builder()
5051
private var dispatcherExecutorService: ExecutorService? = null
5152
private var proxy: Proxy? = null
53+
private var proxyAuthenticator: ProxyAuthenticator? = null
5254
private var maxIdleConnections: Int? = null
5355
private var keepAliveDuration: Duration? = null
5456
private var sslSocketFactory: SSLSocketFactory? = null
@@ -79,6 +81,20 @@ class StagehandOkHttpClient private constructor() {
7981
/** Alias for calling [Builder.proxy] with `proxy.orElse(null)`. */
8082
fun proxy(proxy: Optional<Proxy>) = proxy(proxy.getOrNull())
8183

84+
/**
85+
* Provides credentials when an HTTP proxy responds with `407 Proxy Authentication
86+
* Required`.
87+
*/
88+
fun proxyAuthenticator(proxyAuthenticator: ProxyAuthenticator?) = apply {
89+
this.proxyAuthenticator = proxyAuthenticator
90+
}
91+
92+
/**
93+
* Alias for calling [Builder.proxyAuthenticator] with `proxyAuthenticator.orElse(null)`.
94+
*/
95+
fun proxyAuthenticator(proxyAuthenticator: Optional<ProxyAuthenticator>) =
96+
proxyAuthenticator(proxyAuthenticator.getOrNull())
97+
8298
/**
8399
* The maximum number of idle connections kept by the underlying OkHttp connection pool.
84100
*
@@ -386,6 +402,7 @@ class StagehandOkHttpClient private constructor() {
386402
OkHttpClient.builder()
387403
.timeout(clientOptions.timeout())
388404
.proxy(proxy)
405+
.proxyAuthenticator(proxyAuthenticator)
389406
.maxIdleConnections(maxIdleConnections)
390407
.keepAliveDuration(keepAliveDuration)
391408
.dispatcherExecutorService(dispatcherExecutorService)

0 commit comments

Comments
 (0)