Skip to content

Commit 60469ef

Browse files
committed
Support provider env vars for model API keys
1 parent 0be87c7 commit 60469ef

File tree

7 files changed

+266
-44
lines changed

7 files changed

+266
-44
lines changed

examples/local_browser_playwright_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def main() -> None:
104104
browserbase_api_key=bb_api_key,
105105
browserbase_project_id=bb_project_id,
106106
model_api_key=model_api_key,
107-
local_openai_api_key=model_api_key,
108107
local_ready_timeout_s=30.0,
109108
) as client:
110109
print("⏳ Starting Stagehand session (local server + local browser)...")

examples/local_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def main() -> None:
5151

5252
client = Stagehand(
5353
server="local",
54-
local_openai_api_key=model_key,
5554
local_ready_timeout_s=30.0,
5655
)
5756

examples/local_server_multiregion_browser_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def main() -> None:
7575
browserbase_api_key=bb_api_key,
7676
browserbase_project_id=bb_project_id,
7777
model_api_key=model_api_key,
78-
local_openai_api_key=model_api_key,
7978
local_ready_timeout_s=30.0,
8079
) as client:
8180
print("⏳ Starting Stagehand session (local server + Browserbase browser)...")

src/stagehand/_client.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,36 @@
4747
"AsyncClient",
4848
]
4949

50+
_MODEL_API_KEY_ENV_VARS: tuple[str, ...] = (
51+
"MODEL_API_KEY",
52+
"OPENAI_API_KEY",
53+
"ANTHROPIC_API_KEY",
54+
"GEMINI_API_KEY",
55+
"GOOGLE_GENERATIVE_AI_API_KEY",
56+
"GOOGLE_API_KEY",
57+
"GOOGLE_VERTEX_AI_API_KEY",
58+
"GROQ_API_KEY",
59+
"CEREBRAS_API_KEY",
60+
"TOGETHER_AI_API_KEY",
61+
"MISTRAL_API_KEY",
62+
"DEEPSEEK_API_KEY",
63+
"PERPLEXITY_API_KEY",
64+
"AZURE_API_KEY",
65+
"XAI_API_KEY",
66+
)
67+
68+
69+
def _resolve_model_api_key(model_api_key: str | None) -> str | None:
70+
if model_api_key is not None:
71+
return model_api_key
72+
73+
for env_var in _MODEL_API_KEY_ENV_VARS:
74+
value = os.environ.get(env_var)
75+
if value:
76+
return value
77+
78+
return None
79+
5080

5181
class Stagehand(SyncAPIClient):
5282
# client options
@@ -93,7 +123,7 @@ def __init__(
93123
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
94124
- `browserbase_api_key` from `BROWSERBASE_API_KEY`
95125
- `browserbase_project_id` from `BROWSERBASE_PROJECT_ID`
96-
- `model_api_key` from `MODEL_API_KEY`
126+
- `model_api_key` from `MODEL_API_KEY` or a recognized provider API key env var
97127
"""
98128
self._server_mode: Literal["remote", "local"] = server
99129
self._local_stagehand_binary_path = _local_stagehand_binary_path
@@ -113,11 +143,11 @@ def __init__(
113143
self.browserbase_api_key = browserbase_api_key
114144
self.browserbase_project_id = browserbase_project_id
115145

116-
if model_api_key is None:
117-
model_api_key = os.environ.get("MODEL_API_KEY")
146+
model_api_key = _resolve_model_api_key(model_api_key)
118147
if model_api_key is None:
119148
raise StagehandError(
120-
"The model_api_key client option must be set either by passing model_api_key to the client or by setting the MODEL_API_KEY environment variable"
149+
"The model_api_key client option must be set either by passing model_api_key to the client "
150+
f"or by setting one of the supported environment variables: {', '.join(_MODEL_API_KEY_ENV_VARS)}"
121151
)
122152
self.model_api_key = model_api_key
123153

@@ -127,14 +157,14 @@ def __init__(
127157
if base_url is None:
128158
base_url = "http://127.0.0.1"
129159

130-
openai_api_key = local_openai_api_key or os.environ.get("OPENAI_API_KEY") or model_api_key
160+
local_model_api_key = local_openai_api_key or model_api_key
131161
self._sea_server = SeaServerManager(
132162
config=SeaServerConfig(
133163
host=local_host,
134164
port=local_port,
135165
headless=local_headless,
136166
ready_timeout_s=local_ready_timeout_s,
137-
openai_api_key=openai_api_key,
167+
model_api_key=local_model_api_key,
138168
chrome_path=local_chrome_path,
139169
shutdown_on_close=local_shutdown_on_close,
140170
),
@@ -381,7 +411,7 @@ def __init__(
381411
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
382412
- `browserbase_api_key` from `BROWSERBASE_API_KEY`
383413
- `browserbase_project_id` from `BROWSERBASE_PROJECT_ID`
384-
- `model_api_key` from `MODEL_API_KEY`
414+
- `model_api_key` from `MODEL_API_KEY` or a recognized provider API key env var
385415
"""
386416
self._server_mode: Literal["remote", "local"] = server
387417
self._local_stagehand_binary_path = _local_stagehand_binary_path
@@ -401,11 +431,11 @@ def __init__(
401431
self.browserbase_api_key = browserbase_api_key
402432
self.browserbase_project_id = browserbase_project_id
403433

404-
if model_api_key is None:
405-
model_api_key = os.environ.get("MODEL_API_KEY")
434+
model_api_key = _resolve_model_api_key(model_api_key)
406435
if model_api_key is None:
407436
raise StagehandError(
408-
"The model_api_key client option must be set either by passing model_api_key to the client or by setting the MODEL_API_KEY environment variable"
437+
"The model_api_key client option must be set either by passing model_api_key to the client "
438+
f"or by setting one of the supported environment variables: {', '.join(_MODEL_API_KEY_ENV_VARS)}"
409439
)
410440
self.model_api_key = model_api_key
411441

@@ -414,14 +444,14 @@ def __init__(
414444
if base_url is None:
415445
base_url = "http://127.0.0.1"
416446

417-
openai_api_key = local_openai_api_key or os.environ.get("OPENAI_API_KEY") or model_api_key
447+
local_model_api_key = local_openai_api_key or model_api_key
418448
self._sea_server = SeaServerManager(
419449
config=SeaServerConfig(
420450
host=local_host,
421451
port=local_port,
422452
headless=local_headless,
423453
ready_timeout_s=local_ready_timeout_s,
424-
openai_api_key=openai_api_key,
454+
model_api_key=local_model_api_key,
425455
chrome_path=local_chrome_path,
426456
shutdown_on_close=local_shutdown_on_close,
427457
),

src/stagehand/lib/sea_server.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class SeaServerConfig:
2424
port: int
2525
headless: bool
2626
ready_timeout_s: float
27-
openai_api_key: str | None
27+
model_api_key: str | None
2828
chrome_path: str | None
2929
shutdown_on_close: bool
3030

@@ -118,6 +118,22 @@ def __init__(
118118
def base_url(self) -> str | None:
119119
return self._base_url
120120

121+
def _build_process_env(self, *, port: int) -> dict[str, str]:
122+
proc_env = dict(os.environ)
123+
# Defaults that make the server boot under SEA (avoid pino-pretty transport)
124+
proc_env.setdefault("NODE_ENV", "production")
125+
# Server package expects BB_ENV to be set (see packages/server/src/lib/env.ts)
126+
proc_env.setdefault("BB_ENV", "local")
127+
proc_env["HOST"] = self._config.host
128+
proc_env["PORT"] = str(port)
129+
proc_env["HEADLESS"] = "true" if self._config.headless else "false"
130+
if self._config.model_api_key:
131+
proc_env["MODEL_API_KEY"] = self._config.model_api_key
132+
if self._config.chrome_path:
133+
proc_env["CHROME_PATH"] = self._config.chrome_path
134+
proc_env["LIGHTHOUSE_CHROMIUM_PATH"] = self._config.chrome_path
135+
return proc_env
136+
121137
def ensure_running_sync(self) -> str:
122138
with self._lock:
123139
if self._proc is not None and self._proc.poll() is None and self._base_url is not None:
@@ -169,20 +185,7 @@ def _start_sync(self) -> tuple[str, subprocess.Popen[bytes]]:
169185

170186
port = _pick_free_port(self._config.host) if self._config.port == 0 else self._config.port
171187
base_url = _build_base_url(host=self._config.host, port=port)
172-
173-
proc_env = dict(os.environ)
174-
# Defaults that make the server boot under SEA (avoid pino-pretty transport)
175-
proc_env.setdefault("NODE_ENV", "production")
176-
# Server package expects BB_ENV to be set (see packages/server/src/lib/env.ts)
177-
proc_env.setdefault("BB_ENV", "local")
178-
proc_env["HOST"] = self._config.host
179-
proc_env["PORT"] = str(port)
180-
proc_env["HEADLESS"] = "true" if self._config.headless else "false"
181-
if self._config.openai_api_key:
182-
proc_env["OPENAI_API_KEY"] = self._config.openai_api_key
183-
if self._config.chrome_path:
184-
proc_env["CHROME_PATH"] = self._config.chrome_path
185-
proc_env["LIGHTHOUSE_CHROMIUM_PATH"] = self._config.chrome_path
188+
proc_env = self._build_process_env(port=port)
186189

187190
preexec_fn = None
188191
creationflags = 0
@@ -221,18 +224,7 @@ async def _start_async(self) -> tuple[str, subprocess.Popen[bytes]]:
221224

222225
port = _pick_free_port(self._config.host) if self._config.port == 0 else self._config.port
223226
base_url = _build_base_url(host=self._config.host, port=port)
224-
225-
proc_env = dict(os.environ)
226-
proc_env.setdefault("NODE_ENV", "production")
227-
proc_env.setdefault("BB_ENV", "local")
228-
proc_env["HOST"] = self._config.host
229-
proc_env["PORT"] = str(port)
230-
proc_env["HEADLESS"] = "true" if self._config.headless else "false"
231-
if self._config.openai_api_key:
232-
proc_env["OPENAI_API_KEY"] = self._config.openai_api_key
233-
if self._config.chrome_path:
234-
proc_env["CHROME_PATH"] = self._config.chrome_path
235-
proc_env["LIGHTHOUSE_CHROMIUM_PATH"] = self._config.chrome_path
227+
proc_env = self._build_process_env(port=port)
236228

237229
preexec_fn = None
238230
creationflags = 0

tests/test_client.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pydantic import ValidationError
2121

2222
from stagehand import Stagehand, AsyncStagehand, APIResponseValidationError
23+
from stagehand._client import _MODEL_API_KEY_ENV_VARS
2324
from stagehand._types import Omit
2425
from stagehand._utils import asyncify
2526
from stagehand._models import BaseModel, FinalRequestOptions
@@ -45,6 +46,10 @@
4546
model_api_key = "My Model API Key"
4647

4748

49+
def _omit_model_api_key_env_vars() -> dict[str, Omit]:
50+
return {name: Omit() for name in _MODEL_API_KEY_ENV_VARS}
51+
52+
4853
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
4954
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
5055
url = httpx.URL(request.url)
@@ -469,7 +474,7 @@ def test_validate_headers(self) -> None:
469474
**{
470475
"BROWSERBASE_API_KEY": Omit(),
471476
"BROWSERBASE_PROJECT_ID": Omit(),
472-
"MODEL_API_KEY": Omit(),
477+
**_omit_model_api_key_env_vars(),
473478
}
474479
):
475480
client2 = Stagehand(
@@ -481,6 +486,35 @@ def test_validate_headers(self) -> None:
481486
)
482487
client2.sessions.start(model_name="openai/gpt-5-nano")
483488

489+
def test_model_api_key_falls_back_to_openai_env(self) -> None:
490+
with update_env(
491+
MODEL_API_KEY=Omit(),
492+
OPENAI_API_KEY="openai-key",
493+
):
494+
client = Stagehand(
495+
base_url=base_url,
496+
browserbase_api_key=browserbase_api_key,
497+
browserbase_project_id=browserbase_project_id,
498+
model_api_key=None,
499+
)
500+
501+
assert client.model_api_key == "openai-key"
502+
503+
def test_model_api_key_falls_back_to_gemini_env(self) -> None:
504+
with update_env(
505+
MODEL_API_KEY=Omit(),
506+
OPENAI_API_KEY=Omit(),
507+
GEMINI_API_KEY="gemini-key",
508+
):
509+
client = Stagehand(
510+
base_url=base_url,
511+
browserbase_api_key=browserbase_api_key,
512+
browserbase_project_id=browserbase_project_id,
513+
model_api_key=None,
514+
)
515+
516+
assert client.model_api_key == "gemini-key"
517+
484518
def test_default_query_option(self) -> None:
485519
client = Stagehand(
486520
base_url=base_url,
@@ -1517,7 +1551,7 @@ def test_validate_headers(self) -> None:
15171551
**{
15181552
"BROWSERBASE_API_KEY": Omit(),
15191553
"BROWSERBASE_PROJECT_ID": Omit(),
1520-
"MODEL_API_KEY": Omit(),
1554+
**_omit_model_api_key_env_vars(),
15211555
}
15221556
):
15231557
client2 = AsyncStagehand(
@@ -1529,6 +1563,35 @@ def test_validate_headers(self) -> None:
15291563
)
15301564
_ = client2
15311565

1566+
async def test_model_api_key_falls_back_to_openai_env(self) -> None:
1567+
with update_env(
1568+
MODEL_API_KEY=Omit(),
1569+
OPENAI_API_KEY="openai-key",
1570+
):
1571+
client = AsyncStagehand(
1572+
base_url=base_url,
1573+
browserbase_api_key=browserbase_api_key,
1574+
browserbase_project_id=browserbase_project_id,
1575+
model_api_key=None,
1576+
)
1577+
1578+
assert client.model_api_key == "openai-key"
1579+
1580+
async def test_model_api_key_falls_back_to_gemini_env(self) -> None:
1581+
with update_env(
1582+
MODEL_API_KEY=Omit(),
1583+
OPENAI_API_KEY=Omit(),
1584+
GEMINI_API_KEY="gemini-key",
1585+
):
1586+
client = AsyncStagehand(
1587+
base_url=base_url,
1588+
browserbase_api_key=browserbase_api_key,
1589+
browserbase_project_id=browserbase_project_id,
1590+
model_api_key=None,
1591+
)
1592+
1593+
assert client.model_api_key == "gemini-key"
1594+
15321595
async def test_default_query_option(self) -> None:
15331596
client = AsyncStagehand(
15341597
base_url=base_url,

0 commit comments

Comments
 (0)