|
1 | 1 | import os |
2 | | -from typing import Any, Callable, Literal, Optional |
| 2 | +from typing import Any, Callable, Literal, Optional, Union |
3 | 3 |
|
4 | 4 | from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams |
5 | | -from pydantic import BaseModel, ConfigDict, Field |
| 5 | +from pydantic import BaseModel, ConfigDict, Field, field_validator |
6 | 6 |
|
7 | 7 | from stagehand.schemas import AvailableModel |
8 | 8 |
|
@@ -71,7 +71,7 @@ class StagehandConfig(BaseModel): |
71 | 71 | alias="domSettleTimeoutMs", |
72 | 72 | description="Timeout for DOM to settle (in ms)", |
73 | 73 | ) |
74 | | - browserbase_session_create_params: Optional[BrowserbaseSessionCreateParams] = Field( |
| 74 | + browserbase_session_create_params: Optional[Union[BrowserbaseSessionCreateParams, dict[str, Any]]] = Field( |
75 | 75 | None, |
76 | 76 | alias="browserbaseSessionCreateParams", |
77 | 77 | description="Browserbase session create params", |
@@ -117,6 +117,17 @@ class StagehandConfig(BaseModel): |
117 | 117 | ) |
118 | 118 |
|
119 | 119 | model_config = ConfigDict(populate_by_name=True) |
| 120 | + |
| 121 | + @field_validator('browserbase_session_create_params', mode='before') |
| 122 | + @classmethod |
| 123 | + def validate_browserbase_params(cls, v, info): |
| 124 | + """Validate and convert browserbase session create params.""" |
| 125 | + if isinstance(v, dict) and 'project_id' not in v: |
| 126 | + values = info.data |
| 127 | + project_id = values.get('project_id') or values.get('projectId') |
| 128 | + if project_id: |
| 129 | + v = {**v, 'project_id': project_id} |
| 130 | + return v |
120 | 131 |
|
121 | 132 | def with_overrides(self, **overrides) -> "StagehandConfig": |
122 | 133 | """ |
|
0 commit comments