Skip to content

Commit 5201ec1

Browse files
fix: sanitize endpoint path params
1 parent 0faf1df commit 5201ec1

File tree

4 files changed

+232
-15
lines changed

4 files changed

+232
-15
lines changed

src/stagehand/_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from ._path import path_template as path_template
12
from ._sync import asyncify as asyncify
23
from ._proxy import LazyProxy as LazyProxy
34
from ._utils import (

src/stagehand/_utils/_path.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from __future__ import annotations
2+
3+
import re
4+
from typing import (
5+
Any,
6+
Mapping,
7+
Callable,
8+
)
9+
from urllib.parse import quote
10+
11+
# Matches '.' or '..' where each dot is either literal or percent-encoded (%2e / %2E).
12+
_DOT_SEGMENT_RE = re.compile(r"^(?:\.|%2[eE]){1,2}$")
13+
14+
_PLACEHOLDER_RE = re.compile(r"\{(\w+)\}")
15+
16+
17+
def _quote_path_segment_part(value: str) -> str:
18+
"""Percent-encode `value` for use in a URI path segment.
19+
20+
Considers characters not in `pchar` set from RFC 3986 §3.3 to be unsafe.
21+
https://datatracker.ietf.org/doc/html/rfc3986#section-3.3
22+
"""
23+
# quote() already treats unreserved characters (letters, digits, and -._~)
24+
# as safe, so we only need to add sub-delims, ':', and '@'.
25+
# Notably, unlike the default `safe` for quote(), / is unsafe and must be quoted.
26+
return quote(value, safe="!$&'()*+,;=:@")
27+
28+
29+
def _quote_query_part(value: str) -> str:
30+
"""Percent-encode `value` for use in a URI query string.
31+
32+
Considers &, = and characters not in `query` set from RFC 3986 §3.4 to be unsafe.
33+
https://datatracker.ietf.org/doc/html/rfc3986#section-3.4
34+
"""
35+
return quote(value, safe="!$'()*+,;:@/?")
36+
37+
38+
def _quote_fragment_part(value: str) -> str:
39+
"""Percent-encode `value` for use in a URI fragment.
40+
41+
Considers characters not in `fragment` set from RFC 3986 §3.5 to be unsafe.
42+
https://datatracker.ietf.org/doc/html/rfc3986#section-3.5
43+
"""
44+
return quote(value, safe="!$&'()*+,;=:@/?")
45+
46+
47+
def _interpolate(
48+
template: str,
49+
values: Mapping[str, Any],
50+
quoter: Callable[[str], str],
51+
) -> str:
52+
"""Replace {name} placeholders in `template`, quoting each value with `quoter`.
53+
54+
Placeholder names are looked up in `values`.
55+
56+
Raises:
57+
KeyError: If a placeholder is not found in `values`.
58+
"""
59+
# re.split with a capturing group returns alternating
60+
# [text, name, text, name, ..., text] elements.
61+
parts = _PLACEHOLDER_RE.split(template)
62+
63+
for i in range(1, len(parts), 2):
64+
name = parts[i]
65+
if name not in values:
66+
raise KeyError(f"a value for placeholder {{{name}}} was not provided")
67+
val = values[name]
68+
if val is None:
69+
parts[i] = "null"
70+
elif isinstance(val, bool):
71+
parts[i] = "true" if val else "false"
72+
else:
73+
parts[i] = quoter(str(values[name]))
74+
75+
return "".join(parts)
76+
77+
78+
def path_template(template: str, /, **kwargs: Any) -> str:
79+
"""Interpolate {name} placeholders in `template` from keyword arguments.
80+
81+
Args:
82+
template: The template string containing {name} placeholders.
83+
**kwargs: Keyword arguments to interpolate into the template.
84+
85+
Returns:
86+
The template with placeholders interpolated and percent-encoded.
87+
88+
Safe characters for percent-encoding are dependent on the URI component.
89+
Placeholders in path and fragment portions are percent-encoded where the `segment`
90+
and `fragment` sets from RFC 3986 respectively are considered safe.
91+
Placeholders in the query portion are percent-encoded where the `query` set from
92+
RFC 3986 §3.3 is considered safe except for = and & characters.
93+
94+
Raises:
95+
KeyError: If a placeholder is not found in `kwargs`.
96+
ValueError: If resulting path contains /./ or /../ segments (including percent-encoded dot-segments).
97+
"""
98+
# Split the template into path, query, and fragment portions.
99+
fragment_template: str | None = None
100+
query_template: str | None = None
101+
102+
rest = template
103+
if "#" in rest:
104+
rest, fragment_template = rest.split("#", 1)
105+
if "?" in rest:
106+
rest, query_template = rest.split("?", 1)
107+
path_template = rest
108+
109+
# Interpolate each portion with the appropriate quoting rules.
110+
path_result = _interpolate(path_template, kwargs, _quote_path_segment_part)
111+
112+
# Reject dot-segments (. and ..) in the final assembled path. The check
113+
# runs after interpolation so that adjacent placeholders or a mix of static
114+
# text and placeholders that together form a dot-segment are caught.
115+
# Also reject percent-encoded dot-segments to protect against incorrectly
116+
# implemented normalization in servers/proxies.
117+
for segment in path_result.split("/"):
118+
if _DOT_SEGMENT_RE.match(segment):
119+
raise ValueError(f"Constructed path {path_result!r} contains dot-segment {segment!r} which is not allowed")
120+
121+
result = path_result
122+
if query_template is not None:
123+
result += "?" + _interpolate(query_template, kwargs, _quote_query_part)
124+
if fragment_template is not None:
125+
result += "#" + _interpolate(fragment_template, kwargs, _quote_fragment_part)
126+
127+
return result

src/stagehand/resources/sessions.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
session_navigate_params,
1717
)
1818
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
19-
from .._utils import is_given, required_args, maybe_transform, strip_not_given, async_maybe_transform
19+
from .._utils import is_given, path_template, required_args, maybe_transform, strip_not_given, async_maybe_transform
2020
from .._compat import cached_property
2121
from .._resource import SyncAPIResource, AsyncAPIResource
2222
from .._response import (
@@ -212,7 +212,7 @@ def act(
212212
**(extra_headers or {}),
213213
}
214214
return self._post(
215-
f"/v1/sessions/{id}/act",
215+
path_template("/v1/sessions/{id}/act", id=id),
216216
body=maybe_transform(
217217
{
218218
"input": input,
@@ -269,7 +269,7 @@ def end(
269269
**(extra_headers or {}),
270270
}
271271
return self._post(
272-
f"/v1/sessions/{id}/end",
272+
path_template("/v1/sessions/{id}/end", id=id),
273273
options=make_request_options(
274274
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
275275
),
@@ -429,7 +429,7 @@ def execute(
429429
**(extra_headers or {}),
430430
}
431431
return self._post(
432-
f"/v1/sessions/{id}/agentExecute",
432+
path_template("/v1/sessions/{id}/agentExecute", id=id),
433433
body=maybe_transform(
434434
{
435435
"agent_config": agent_config,
@@ -608,7 +608,7 @@ def extract(
608608
**(extra_headers or {}),
609609
}
610610
return self._post(
611-
f"/v1/sessions/{id}/extract",
611+
path_template("/v1/sessions/{id}/extract", id=id),
612612
body=maybe_transform(
613613
{
614614
"frame_id": frame_id,
@@ -676,7 +676,7 @@ def navigate(
676676
**(extra_headers or {}),
677677
}
678678
return self._post(
679-
f"/v1/sessions/{id}/navigate",
679+
path_template("/v1/sessions/{id}/navigate", id=id),
680680
body=maybe_transform(
681681
{
682682
"url": url,
@@ -843,7 +843,7 @@ def observe(
843843
**(extra_headers or {}),
844844
}
845845
return self._post(
846-
f"/v1/sessions/{id}/observe",
846+
path_template("/v1/sessions/{id}/observe", id=id),
847847
body=maybe_transform(
848848
{
849849
"frame_id": frame_id,
@@ -900,7 +900,7 @@ def replay(
900900
**(extra_headers or {}),
901901
}
902902
return self._get(
903-
f"/v1/sessions/{id}/replay",
903+
path_template("/v1/sessions/{id}/replay", id=id),
904904
options=make_request_options(
905905
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
906906
),
@@ -1164,7 +1164,7 @@ async def act(
11641164
**(extra_headers or {}),
11651165
}
11661166
return await self._post(
1167-
f"/v1/sessions/{id}/act",
1167+
path_template("/v1/sessions/{id}/act", id=id),
11681168
body=await async_maybe_transform(
11691169
{
11701170
"input": input,
@@ -1221,7 +1221,7 @@ async def end(
12211221
**(extra_headers or {}),
12221222
}
12231223
return await self._post(
1224-
f"/v1/sessions/{id}/end",
1224+
path_template("/v1/sessions/{id}/end", id=id),
12251225
options=make_request_options(
12261226
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
12271227
),
@@ -1381,7 +1381,7 @@ async def execute(
13811381
**(extra_headers or {}),
13821382
}
13831383
return await self._post(
1384-
f"/v1/sessions/{id}/agentExecute",
1384+
path_template("/v1/sessions/{id}/agentExecute", id=id),
13851385
body=await async_maybe_transform(
13861386
{
13871387
"agent_config": agent_config,
@@ -1560,7 +1560,7 @@ async def extract(
15601560
**(extra_headers or {}),
15611561
}
15621562
return await self._post(
1563-
f"/v1/sessions/{id}/extract",
1563+
path_template("/v1/sessions/{id}/extract", id=id),
15641564
body=await async_maybe_transform(
15651565
{
15661566
"frame_id": frame_id,
@@ -1628,7 +1628,7 @@ async def navigate(
16281628
**(extra_headers or {}),
16291629
}
16301630
return await self._post(
1631-
f"/v1/sessions/{id}/navigate",
1631+
path_template("/v1/sessions/{id}/navigate", id=id),
16321632
body=await async_maybe_transform(
16331633
{
16341634
"url": url,
@@ -1795,7 +1795,7 @@ async def observe(
17951795
**(extra_headers or {}),
17961796
}
17971797
return await self._post(
1798-
f"/v1/sessions/{id}/observe",
1798+
path_template("/v1/sessions/{id}/observe", id=id),
17991799
body=await async_maybe_transform(
18001800
{
18011801
"frame_id": frame_id,
@@ -1852,7 +1852,7 @@ async def replay(
18521852
**(extra_headers or {}),
18531853
}
18541854
return await self._get(
1855-
f"/v1/sessions/{id}/replay",
1855+
path_template("/v1/sessions/{id}/replay", id=id),
18561856
options=make_request_options(
18571857
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
18581858
),

tests/test_utils/test_path.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import pytest
6+
7+
from stagehand._utils._path import path_template
8+
9+
10+
@pytest.mark.parametrize(
11+
"template, kwargs, expected",
12+
[
13+
("/v1/{id}", dict(id="abc"), "/v1/abc"),
14+
("/v1/{a}/{b}", dict(a="x", b="y"), "/v1/x/y"),
15+
("/v1/{a}{b}/path/{c}?val={d}#{e}", dict(a="x", b="y", c="z", d="u", e="v"), "/v1/xy/path/z?val=u#v"),
16+
("/{w}/{w}", dict(w="echo"), "/echo/echo"),
17+
("/v1/static", {}, "/v1/static"),
18+
("", {}, ""),
19+
("/v1/?q={n}&count=10", dict(n=42), "/v1/?q=42&count=10"),
20+
("/v1/{v}", dict(v=None), "/v1/null"),
21+
("/v1/{v}", dict(v=True), "/v1/true"),
22+
("/v1/{v}", dict(v=False), "/v1/false"),
23+
("/v1/{v}", dict(v=".hidden"), "/v1/.hidden"), # dot prefix ok
24+
("/v1/{v}", dict(v="file.txt"), "/v1/file.txt"), # dot in middle ok
25+
("/v1/{v}", dict(v="..."), "/v1/..."), # triple dot ok
26+
("/v1/{a}{b}", dict(a=".", b="txt"), "/v1/.txt"), # dot var combining with adjacent to be ok
27+
("/items?q={v}#{f}", dict(v=".", f=".."), "/items?q=.#.."), # dots in query/fragment are fine
28+
(
29+
"/v1/{a}?query={b}",
30+
dict(a="../../other/endpoint", b="a&bad=true"),
31+
"/v1/..%2F..%2Fother%2Fendpoint?query=a%26bad%3Dtrue",
32+
),
33+
("/v1/{val}", dict(val="a/b/c"), "/v1/a%2Fb%2Fc"),
34+
("/v1/{val}", dict(val="a/b/c?query=value"), "/v1/a%2Fb%2Fc%3Fquery=value"),
35+
("/v1/{val}", dict(val="a/b/c?query=value&bad=true"), "/v1/a%2Fb%2Fc%3Fquery=value&bad=true"),
36+
("/v1/{val}", dict(val="%20"), "/v1/%2520"), # escapes escape sequences in input
37+
# Query: slash and ? are safe, # is not
38+
("/items?q={v}", dict(v="a/b"), "/items?q=a/b"),
39+
("/items?q={v}", dict(v="a?b"), "/items?q=a?b"),
40+
("/items?q={v}", dict(v="a#b"), "/items?q=a%23b"),
41+
("/items?q={v}", dict(v="a b"), "/items?q=a%20b"),
42+
# Fragment: slash and ? are safe
43+
("/docs#{v}", dict(v="a/b"), "/docs#a/b"),
44+
("/docs#{v}", dict(v="a?b"), "/docs#a?b"),
45+
# Path: slash, ? and # are all encoded
46+
("/v1/{v}", dict(v="a/b"), "/v1/a%2Fb"),
47+
("/v1/{v}", dict(v="a?b"), "/v1/a%3Fb"),
48+
("/v1/{v}", dict(v="a#b"), "/v1/a%23b"),
49+
# same var encoded differently by component
50+
(
51+
"/v1/{v}?q={v}#{v}",
52+
dict(v="a/b?c#d"),
53+
"/v1/a%2Fb%3Fc%23d?q=a/b?c%23d#a/b?c%23d",
54+
),
55+
("/v1/{val}", dict(val="x?admin=true"), "/v1/x%3Fadmin=true"), # query injection
56+
("/v1/{val}", dict(val="x#admin"), "/v1/x%23admin"), # fragment injection
57+
],
58+
)
59+
def test_interpolation(template: str, kwargs: dict[str, Any], expected: str) -> None:
60+
assert path_template(template, **kwargs) == expected
61+
62+
63+
def test_missing_kwarg_raises_key_error() -> None:
64+
with pytest.raises(KeyError, match="org_id"):
65+
path_template("/v1/{org_id}")
66+
67+
68+
@pytest.mark.parametrize(
69+
"template, kwargs",
70+
[
71+
("{a}/path", dict(a=".")),
72+
("{a}/path", dict(a="..")),
73+
("/v1/{a}", dict(a=".")),
74+
("/v1/{a}", dict(a="..")),
75+
("/v1/{a}/path", dict(a=".")),
76+
("/v1/{a}/path", dict(a="..")),
77+
("/v1/{a}{b}", dict(a=".", b=".")), # adjacent vars → ".."
78+
("/v1/{a}.", dict(a=".")), # var + static → ".."
79+
("/v1/{a}{b}", dict(a="", b=".")), # empty + dot → "."
80+
("/v1/%2e/{x}", dict(x="ok")), # encoded dot in static text
81+
("/v1/%2e./{x}", dict(x="ok")), # mixed encoded ".." in static
82+
("/v1/.%2E/{x}", dict(x="ok")), # mixed encoded ".." in static
83+
("/v1/{v}?q=1", dict(v="..")),
84+
("/v1/{v}#frag", dict(v="..")),
85+
],
86+
)
87+
def test_dot_segment_rejected(template: str, kwargs: dict[str, Any]) -> None:
88+
with pytest.raises(ValueError, match="dot-segment"):
89+
path_template(template, **kwargs)

0 commit comments

Comments
 (0)