Skip to content

Commit 9a73434

Browse files
committed
Enforce recursion limit and more tidying
1 parent f6909ca commit 9a73434

File tree

11 files changed

+91
-99
lines changed

11 files changed

+91
-99
lines changed

jsonpath/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""JSONPath, JSON Pointer and JSON Patch command line interface."""
2+
23
import argparse
34
import json
45
import sys
@@ -289,7 +290,6 @@ def handle_pointer_command(args: argparse.Namespace) -> None:
289290
if args.pointer is not None:
290291
pointer = args.pointer
291292
else:
292-
# TODO: is a property with a trailing newline OK?
293293
pointer = args.pointer_file.read().strip()
294294

295295
try:

jsonpath/env.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ class attributes `root_token`, `self_token` and `filter_context_token`.
124124
index. Defaults to `(2**53) - 1`.
125125
min_int_index (int): The minimum integer allowed when selecting array items by
126126
index. Defaults to `-(2**53) + 1`.
127+
max_recursion_depth (int): The maximum depth of nested dicts/objects and/or
128+
arrays/lists the recursive descent selector can descend to before a
129+
`JSONPathRecursionError` is raised. Defaults to `100`.
127130
parser_class: The parser to use when parsing tokens from the lexer.
128131
root_token (str): The pattern used to select the root node in a JSON document.
129132
Defaults to `"$"`.
@@ -132,8 +135,8 @@ class attributes `root_token`, `self_token` and `filter_context_token`.
132135
union_token (str): The pattern used as the union operator. Defaults to `"|"`.
133136
"""
134137

135-
# These should be unescaped strings. `re.escape` will be called
136-
# on them automatically when compiling lexer rules.
138+
# These should be unescaped strings. `re.escape` will be called on them
139+
# automatically when compiling lexer rules.
137140
pseudo_root_token = "^"
138141
filter_context_token = "_"
139142
intersection_token = "&"
@@ -146,6 +149,7 @@ class attributes `root_token`, `self_token` and `filter_context_token`.
146149

147150
max_int_index = (2**53) - 1
148151
min_int_index = -(2**53) + 1
152+
max_recursion_depth = 100
149153

150154
# Override these to customize path tokenization and parsing.
151155
lexer_class: Type[Lexer] = Lexer
@@ -227,7 +231,6 @@ def compile(self, path: str) -> Union[JSONPath, CompoundJSONPath]: # noqa: A003
227231
"unexpected whitespace", token=stream.tokens[stream.pos - 1]
228232
)
229233

230-
# TODO: better!
231234
if stream.current().kind != TOKEN_EOF:
232235
_path = CompoundJSONPath(env=self, path=_path)
233236
while stream.current().kind != TOKEN_EOF:

jsonpath/exceptions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ def __init__(self, *args: object, token: Token) -> None:
7777
self.token = token
7878

7979

80+
class JSONPathRecursionError(JSONPathError):
81+
"""An exception raised when the maximum recursion depth is exceeded.
82+
83+
Arguments:
84+
args: Arguments passed to `Exception`.
85+
token: The token that caused the error.
86+
"""
87+
88+
def __init__(self, *args: object, token: Token) -> None:
89+
super().__init__(*args)
90+
self.token = token
91+
92+
8093
class JSONPointerError(Exception):
8194
"""Base class for all JSON Pointer errors."""
8295

jsonpath/filter.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -529,11 +529,9 @@ def __str__(self) -> str:
529529
return "@" + str(self.path)[1:]
530530

531531
def evaluate(self, context: FilterContext) -> object:
532-
if isinstance(context.current, str): # TODO: refactor
533-
if self.path.empty():
534-
return context.current
535-
return NodeList()
536-
if not isinstance(context.current, (Sequence, Mapping)):
532+
if isinstance(context.current, str) or not isinstance(
533+
context.current, (Sequence, Mapping)
534+
):
537535
if self.path.empty():
538536
return context.current
539537
return NodeList()
@@ -546,11 +544,9 @@ def evaluate(self, context: FilterContext) -> object:
546544
)
547545

548546
async def evaluate_async(self, context: FilterContext) -> object:
549-
if isinstance(context.current, str): # TODO: refactor
550-
if self.path.empty():
551-
return context.current
552-
return NodeList()
553-
if not isinstance(context.current, (Sequence, Mapping)):
547+
if isinstance(context.current, str) or not isinstance(
548+
context.current, (Sequence, Mapping)
549+
):
554550
if self.path.empty():
555551
return context.current
556552
return NodeList()
@@ -660,15 +656,19 @@ def evaluate(self, context: FilterContext) -> object:
660656
try:
661657
func = context.env.function_extensions[self.name]
662658
except KeyError:
663-
return UNDEFINED # TODO: should probably raise an exception
659+
# This can only happen if the environment's function register has been
660+
# changed since the query was parsed.
661+
return UNDEFINED
664662
args = [arg.evaluate(context) for arg in self.args]
665663
return func(*self._unpack_node_lists(func, args))
666664

667665
async def evaluate_async(self, context: FilterContext) -> object:
668666
try:
669667
func = context.env.function_extensions[self.name]
670668
except KeyError:
671-
return UNDEFINED # TODO: should probably raise an exception
669+
# This can only happen if the environment's function register has been
670+
# changed since the query was parsed.
671+
return UNDEFINED
672672
args = [await arg.evaluate_async(context) for arg in self.args]
673673
return func(*self._unpack_node_lists(func, args))
674674

jsonpath/parse.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,6 @@ def parse_relative_query(self, stream: TokenStream) -> BaseExpression:
722722
def parse_singular_query_selector(
723723
self, stream: TokenStream
724724
) -> SingularQuerySelector:
725-
# TODO: optionally require root identifier
726725
token = (
727726
stream.next() if stream.current().kind == TOKEN_ROOT else stream.current()
728727
)

jsonpath/path.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,6 @@ def intersection(self, path: JSONPath) -> CompoundJSONPath:
474474
paths=self.paths + ((self.env.intersection_token, path),),
475475
)
476476

477-
# TODO: implement empty and singular for CompoundJSONPath
478-
# TODO: add a `segments` property returning segments from all paths
479-
480477

481478
T = TypeVar("T")
482479

jsonpath/segments.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from typing import Sequence
1212
from typing import Tuple
1313

14+
from .exceptions import JSONPathRecursionError
15+
1416
if TYPE_CHECKING:
1517
from .env import JSONPathEnvironment
1618
from .match import JSONPathMatch
@@ -99,7 +101,8 @@ async def resolve_async(
99101

100102
def _visit(self, node: JSONPathMatch, depth: int = 1) -> Iterable[JSONPathMatch]:
101103
"""Depth-first, pre-order node traversal."""
102-
# TODO: check for recursion limit
104+
if depth > self.env.max_recursion_depth:
105+
raise JSONPathRecursionError("recursion limit exceeded", token=self.token)
103106

104107
yield node
105108

tests/test_compliance.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,7 @@ def test_compliance_strict(env: JSONPathEnvironment, case: Case) -> None:
7878

7979
assert case.document is not None
8080
nodes = NodeList(env.finditer(case.selector, case.document))
81-
82-
if case.results is not None:
83-
assert case.results_paths is not None
84-
assert nodes.values() in case.results
85-
assert nodes.paths() in case.results_paths
86-
else:
87-
assert case.result_paths is not None
88-
assert nodes.values() == case.result
89-
assert nodes.paths() == case.result_paths
81+
case.assert_nodes(nodes)
9082

9183

9284
@pytest.mark.parametrize("case", valid_cases(), ids=operator.attrgetter("name"))
@@ -100,15 +92,7 @@ async def coro() -> NodeList:
10092
return NodeList([node async for node in it])
10193

10294
nodes = asyncio.run(coro())
103-
104-
if case.results is not None:
105-
assert case.results_paths is not None
106-
assert nodes.values() in case.results
107-
assert nodes.paths() in case.results_paths
108-
else:
109-
assert case.result_paths is not None
110-
assert nodes.values() == case.result
111-
assert nodes.paths() == case.result_paths
95+
case.assert_nodes(nodes)
11296

11397

11498
@pytest.mark.parametrize("case", invalid_cases(), ids=operator.attrgetter("name"))

tests/test_errors.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from operator import attrgetter
2+
from typing import Any
23
from typing import List
34
from typing import NamedTuple
45

56
import pytest
67

78
from jsonpath import JSONPathEnvironment
9+
from jsonpath.exceptions import JSONPathRecursionError
810
from jsonpath.exceptions import JSONPathSyntaxError
911
from jsonpath.exceptions import JSONPathTypeError
1012

@@ -77,3 +79,29 @@ def test_filter_literals_must_be_compared(
7779
) -> None:
7880
with pytest.raises(JSONPathSyntaxError):
7981
env.compile(case.query)
82+
83+
84+
def test_recursive_data() -> None:
85+
class MockEnv(JSONPathEnvironment):
86+
nondeterministic = False
87+
88+
env = MockEnv()
89+
query = "$..a"
90+
arr: List[Any] = []
91+
data: Any = {"foo": arr}
92+
arr.append(data)
93+
94+
with pytest.raises(JSONPathRecursionError):
95+
env.findall(query, data)
96+
97+
98+
def test_low_recursion_limit() -> None:
99+
class MockEnv(JSONPathEnvironment):
100+
max_recursion_depth = 3
101+
102+
env = MockEnv()
103+
query = "$..a"
104+
data = {"foo": [{"bar": [1, 2, 3]}]}
105+
106+
with pytest.raises(JSONPathRecursionError):
107+
env.findall(query, data)

tests/test_match_function.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

0 commit comments

Comments
 (0)