Skip to content

Commit b3be24a

Browse files
committed
Add Bedrock model client options to Python SDK
1 parent a30258c commit b3be24a

File tree

9 files changed

+407
-111
lines changed

9 files changed

+407
-111
lines changed

src/stagehand/_client.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ._models import FinalRequestOptions
2525
from ._version import __version__
2626
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
27-
from ._exceptions import APIStatusError, StagehandError
27+
from ._exceptions import APIStatusError
2828
from ._base_client import (
2929
DEFAULT_MAX_RETRIES,
3030
SyncAPIClient,
@@ -52,7 +52,7 @@ class Stagehand(SyncAPIClient):
5252
# client options
5353
browserbase_api_key: str | None
5454
browserbase_project_id: str | None
55-
model_api_key: str
55+
model_api_key: str | None
5656

5757
def __init__(
5858
self,
@@ -115,10 +115,6 @@ def __init__(
115115

116116
if model_api_key is None:
117117
model_api_key = os.environ.get("MODEL_API_KEY")
118-
if model_api_key is None:
119-
raise StagehandError(
120-
"The model_api_key client option must be set either by passing model_api_key to the client or by setting the MODEL_API_KEY environment variable"
121-
)
122118
self.model_api_key = model_api_key
123119

124120
self._sea_server: SeaServerManager | None = None
@@ -210,7 +206,7 @@ def _bb_project_id_auth(self) -> dict[str, str]:
210206
@property
211207
def _llm_model_api_key_auth(self) -> dict[str, str]:
212208
model_api_key = self.model_api_key
213-
return {"x-model-api-key": model_api_key}
209+
return {"x-model-api-key": model_api_key} if model_api_key else {}
214210

215211
@property
216212
@override
@@ -273,9 +269,11 @@ def copy(
273269
return self.__class__(
274270
browserbase_api_key=browserbase_api_key or self.browserbase_api_key,
275271
browserbase_project_id=browserbase_project_id or self.browserbase_project_id,
276-
model_api_key=model_api_key or self.model_api_key,
272+
model_api_key=model_api_key if model_api_key is not None else self.model_api_key,
277273
server=server or self._server_mode,
278-
_local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path,
274+
_local_stagehand_binary_path=_local_stagehand_binary_path
275+
if _local_stagehand_binary_path is not None
276+
else self._local_stagehand_binary_path,
279277
local_host=local_host or self._local_host,
280278
local_port=local_port if local_port is not None else self._local_port,
281279
local_headless=local_headless if local_headless is not None else self._local_headless,
@@ -340,7 +338,7 @@ class AsyncStagehand(AsyncAPIClient):
340338
# client options
341339
browserbase_api_key: str | None
342340
browserbase_project_id: str | None
343-
model_api_key: str
341+
model_api_key: str | None
344342

345343
def __init__(
346344
self,
@@ -403,10 +401,6 @@ def __init__(
403401

404402
if model_api_key is None:
405403
model_api_key = os.environ.get("MODEL_API_KEY")
406-
if model_api_key is None:
407-
raise StagehandError(
408-
"The model_api_key client option must be set either by passing model_api_key to the client or by setting the MODEL_API_KEY environment variable"
409-
)
410404
self.model_api_key = model_api_key
411405

412406
self._sea_server: SeaServerManager | None = None
@@ -497,7 +491,7 @@ def _bb_project_id_auth(self) -> dict[str, str]:
497491
@property
498492
def _llm_model_api_key_auth(self) -> dict[str, str]:
499493
model_api_key = self.model_api_key
500-
return {"x-model-api-key": model_api_key}
494+
return {"x-model-api-key": model_api_key} if model_api_key else {}
501495

502496
@property
503497
@override
@@ -560,9 +554,11 @@ def copy(
560554
return self.__class__(
561555
browserbase_api_key=browserbase_api_key or self.browserbase_api_key,
562556
browserbase_project_id=browserbase_project_id or self.browserbase_project_id,
563-
model_api_key=model_api_key or self.model_api_key,
557+
model_api_key=model_api_key if model_api_key is not None else self.model_api_key,
564558
server=server or self._server_mode,
565-
_local_stagehand_binary_path=_local_stagehand_binary_path if _local_stagehand_binary_path is not None else self._local_stagehand_binary_path,
559+
_local_stagehand_binary_path=_local_stagehand_binary_path
560+
if _local_stagehand_binary_path is not None
561+
else self._local_stagehand_binary_path,
566562
local_host=local_host or self._local_host,
567563
local_port=local_port if local_port is not None else self._local_port,
568564
local_headless=local_headless if local_headless is not None else self._local_headless,

src/stagehand/resources/sessions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,7 @@ def start(
915915
browser: session_start_params.Browser | Omit = omit,
916916
browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit,
917917
browserbase_session_id: str | Omit = omit,
918+
model_client_options: session_start_params.ModelClientOptions | Omit = omit,
918919
dom_settle_timeout_ms: float | Omit = omit,
919920
experimental: bool | Omit = omit,
920921
self_heal: bool | Omit = omit,
@@ -976,6 +977,7 @@ def start(
976977
"browser": browser,
977978
"browserbase_session_create_params": browserbase_session_create_params,
978979
"browserbase_session_id": browserbase_session_id,
980+
"model_client_options": model_client_options,
979981
"dom_settle_timeout_ms": dom_settle_timeout_ms,
980982
"experimental": experimental,
981983
"self_heal": self_heal,
@@ -1867,6 +1869,7 @@ async def start(
18671869
browser: session_start_params.Browser | Omit = omit,
18681870
browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit,
18691871
browserbase_session_id: str | Omit = omit,
1872+
model_client_options: session_start_params.ModelClientOptions | Omit = omit,
18701873
dom_settle_timeout_ms: float | Omit = omit,
18711874
experimental: bool | Omit = omit,
18721875
self_heal: bool | Omit = omit,
@@ -1928,6 +1931,7 @@ async def start(
19281931
"browser": browser,
19291932
"browserbase_session_create_params": browserbase_session_create_params,
19301933
"browserbase_session_id": browserbase_session_id,
1934+
"model_client_options": model_client_options,
19311935
"dom_settle_timeout_ms": dom_settle_timeout_ms,
19321936
"experimental": experimental,
19331937
"self_heal": self_heal,

src/stagehand/resources/sessions_helpers.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from typing import Any
56
from typing_extensions import Literal, override
67

78
import httpx
@@ -27,6 +28,26 @@
2728
from ..types.session_start_response import SessionStartResponse
2829

2930

31+
def _has_explicit_aws_credentials(model_config: dict[str, Any]) -> bool:
32+
return any(model_config.get(key) for key in ("access_key_id", "secret_access_key", "session_token"))
33+
34+
35+
def _build_default_model_config(
36+
*,
37+
model_name: str,
38+
model_client_options: session_start_params.ModelClientOptions | Omit,
39+
fallback_api_key: str | None,
40+
) -> dict[str, Any]:
41+
model_config: dict[str, Any] = {"model_name": model_name}
42+
if isinstance(model_client_options, dict):
43+
model_config.update(model_client_options)
44+
45+
if fallback_api_key and "api_key" not in model_config and not _has_explicit_aws_credentials(model_config):
46+
model_config["api_key"] = fallback_api_key
47+
48+
return model_config
49+
50+
3051
class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse):
3152
def __init__(self, sessions: SessionsResourceWithHelpers) -> None: # type: ignore[name-defined]
3253
super().__init__(sessions)
@@ -71,6 +92,7 @@ def start(
7192
browser: session_start_params.Browser | Omit = omit,
7293
browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit,
7394
browserbase_session_id: str | Omit = omit,
95+
model_client_options: session_start_params.ModelClientOptions | Omit = omit,
7496
dom_settle_timeout_ms: float | Omit = omit,
7597
experimental: bool | Omit = omit,
7698
self_heal: bool | Omit = omit,
@@ -89,6 +111,7 @@ def start(
89111
browser=browser,
90112
browserbase_session_create_params=browserbase_session_create_params,
91113
browserbase_session_id=browserbase_session_id,
114+
model_client_options=model_client_options,
92115
dom_settle_timeout_ms=dom_settle_timeout_ms,
93116
experimental=experimental,
94117
self_heal=self_heal,
@@ -101,7 +124,17 @@ def start(
101124
extra_body=extra_body,
102125
timeout=timeout,
103126
)
104-
return Session(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success)
127+
return Session(
128+
self._client,
129+
start_response.data.session_id,
130+
data=start_response.data,
131+
success=start_response.success,
132+
default_model=_build_default_model_config(
133+
model_name=model_name,
134+
model_client_options=model_client_options,
135+
fallback_api_key=self._client.model_api_key,
136+
),
137+
)
105138

106139

107140
class AsyncSessionsResourceWithHelpers(AsyncSessionsResource):
@@ -124,6 +157,7 @@ async def start(
124157
browser: session_start_params.Browser | Omit = omit,
125158
browserbase_session_create_params: session_start_params.BrowserbaseSessionCreateParams | Omit = omit,
126159
browserbase_session_id: str | Omit = omit,
160+
model_client_options: session_start_params.ModelClientOptions | Omit = omit,
127161
dom_settle_timeout_ms: float | Omit = omit,
128162
experimental: bool | Omit = omit,
129163
self_heal: bool | Omit = omit,
@@ -142,6 +176,7 @@ async def start(
142176
browser=browser,
143177
browserbase_session_create_params=browserbase_session_create_params,
144178
browserbase_session_id=browserbase_session_id,
179+
model_client_options=model_client_options,
145180
dom_settle_timeout_ms=dom_settle_timeout_ms,
146181
experimental=experimental,
147182
self_heal=self_heal,
@@ -154,4 +189,14 @@ async def start(
154189
extra_body=extra_body,
155190
timeout=timeout,
156191
)
157-
return AsyncSession(self._client, start_response.data.session_id, data=start_response.data, success=start_response.success)
192+
return AsyncSession(
193+
self._client,
194+
start_response.data.session_id,
195+
data=start_response.data,
196+
success=start_response.success,
197+
default_model=_build_default_model_config(
198+
model_name=model_name,
199+
model_client_options=model_client_options,
200+
fallback_api_key=self._client.model_api_key,
201+
),
202+
)

0 commit comments

Comments
 (0)