Skip to content

Commit 091ba3b

Browse files
Add max_depth parameter to prevent DoS from deeply nested data
1 parent a8a850a commit 091ba3b

File tree

3 files changed

+201
-91
lines changed

3 files changed

+201
-91
lines changed

rest_framework/fields.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,22 +1692,14 @@ def _propagate_depth_to_child(self):
16921692
if hasattr(self.child, '_propagate_depth_to_child'):
16931693
self.child._propagate_depth_to_child()
16941694

1695-
def _check_data_depth(self, data, current=0):
1696-
if self._root_max_depth is not None:
1697-
if isinstance(data, (list, tuple)):
1698-
for item in data:
1699-
if isinstance(item, (list, tuple, dict)):
1700-
next_depth = current + 1
1701-
if next_depth > self._root_max_depth:
1702-
self.fail('max_depth', max_depth=self._root_max_depth)
1703-
self._check_data_depth(item, next_depth)
1704-
elif isinstance(data, dict):
1705-
for value in data.values():
1706-
if isinstance(value, (list, tuple, dict)):
1707-
next_depth = current + 1
1708-
if next_depth > self._root_max_depth:
1709-
self.fail('max_depth', max_depth=self._root_max_depth)
1710-
self._check_data_depth(value, next_depth)
1695+
def _check_data_depth(self, data, current_level):
1696+
items = data.values() if isinstance(data, dict) else data
1697+
for item in items:
1698+
if isinstance(item, (list, tuple, dict)):
1699+
next_level = current_level + 1
1700+
if next_level > self._root_max_depth:
1701+
self.fail('max_depth', max_depth=self._root_max_depth)
1702+
self._check_data_depth(item, next_level)
17111703

17121704
def get_value(self, dictionary):
17131705
if self.field_name not in dictionary:
@@ -1734,9 +1726,12 @@ def to_internal_value(self, data):
17341726
self.fail('not_a_list', input_type=type(data).__name__)
17351727
if not self.allow_empty and len(data) == 0:
17361728
self.fail('empty')
1737-
if self._root_max_depth is not None and self._current_depth > self._root_max_depth:
1738-
self.fail('max_depth', max_depth=self._root_max_depth)
1739-
self._check_data_depth(data, self._current_depth)
1729+
if self._root_max_depth is not None:
1730+
start_level = self._current_depth if self._current_depth > 0 else 1
1731+
if start_level > self._root_max_depth:
1732+
self.fail('max_depth', max_depth=self._root_max_depth)
1733+
if self.max_depth is not None:
1734+
self._check_data_depth(data, start_level)
17401735
return self.run_child_validation(data)
17411736

17421737
def to_representation(self, data):
@@ -1789,7 +1784,7 @@ def __init__(self, **kwargs):
17891784

17901785
def bind(self, field_name, parent):
17911786
super().bind(field_name, parent)
1792-
if hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None:
1787+
if self.max_depth is None and hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None:
17931788
self._root_max_depth = parent._root_max_depth
17941789
self._current_depth = parent._current_depth + 1
17951790
self._propagate_depth_to_child()
@@ -1801,22 +1796,14 @@ def _propagate_depth_to_child(self):
18011796
if hasattr(self.child, '_propagate_depth_to_child'):
18021797
self.child._propagate_depth_to_child()
18031798

1804-
def _check_data_depth(self, data, current=0):
1805-
if self._root_max_depth is not None:
1806-
if isinstance(data, dict):
1807-
for value in data.values():
1808-
if isinstance(value, (list, tuple, dict)):
1809-
next_depth = current + 1
1810-
if next_depth > self._root_max_depth:
1811-
self.fail('max_depth', max_depth=self._root_max_depth)
1812-
self._check_data_depth(value, next_depth)
1813-
elif isinstance(data, (list, tuple)):
1814-
for item in data:
1815-
if isinstance(item, (list, tuple, dict)):
1816-
next_depth = current + 1
1817-
if next_depth > self._root_max_depth:
1818-
self.fail('max_depth', max_depth=self._root_max_depth)
1819-
self._check_data_depth(item, next_depth)
1799+
def _check_data_depth(self, data, current_level):
1800+
items = data.values() if isinstance(data, dict) else data
1801+
for item in items:
1802+
if isinstance(item, (list, tuple, dict)):
1803+
next_level = current_level + 1
1804+
if next_level > self._root_max_depth:
1805+
self.fail('max_depth', max_depth=self._root_max_depth)
1806+
self._check_data_depth(item, next_level)
18201807

18211808
def get_value(self, dictionary):
18221809
# We override the default field access in order to support
@@ -1835,9 +1822,12 @@ def to_internal_value(self, data):
18351822
self.fail('not_a_dict', input_type=type(data).__name__)
18361823
if not self.allow_empty and len(data) == 0:
18371824
self.fail('empty')
1838-
if self._root_max_depth is not None and self._current_depth > self._root_max_depth:
1839-
self.fail('max_depth', max_depth=self._root_max_depth)
1840-
self._check_data_depth(data, self._current_depth)
1825+
if self._root_max_depth is not None:
1826+
start_level = self._current_depth if self._current_depth > 0 else 1
1827+
if start_level > self._root_max_depth:
1828+
self.fail('max_depth', max_depth=self._root_max_depth)
1829+
if self.max_depth is not None:
1830+
self._check_data_depth(data, start_level)
18411831
return self.run_child_validation(data)
18421832

18431833
def to_representation(self, value):

rest_framework/serializers.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@
7676
'read_only', 'write_only', 'required', 'default', 'initial', 'source',
7777
'label', 'help_text', 'style', 'error_messages', 'allow_empty',
7878
'instance', 'data', 'partial', 'context', 'allow_null',
79-
'max_length', 'min_length'
79+
'max_length', 'min_length', 'max_depth'
8080
)
81-
LIST_SERIALIZER_KWARGS_REMOVE = ('allow_empty', 'min_length', 'max_length')
81+
LIST_SERIALIZER_KWARGS_REMOVE = ('allow_empty', 'min_length', 'max_length', 'max_depth')
8282

8383
ALL_FIELDS = '__all__'
8484

@@ -111,20 +111,25 @@ class BaseSerializer(Field):
111111
.data - Available.
112112
"""
113113

114+
default_error_messages = {
115+
'max_depth': _('Nesting depth exceeds maximum allowed depth of {max_depth}.')
116+
}
117+
114118
def __init__(self, instance=None, data=empty, **kwargs):
115119
self.instance = instance
116120
if data is not empty:
117121
self.initial_data = data
118122
self.partial = kwargs.pop('partial', False)
119123
self._context = kwargs.pop('context', {})
120124
kwargs.pop('many', None)
125+
self.max_depth = kwargs.pop('max_depth', None)
121126
super().__init__(**kwargs)
122127
self._current_depth = 0
123-
self._root_max_depth = None
128+
self._root_max_depth = self.max_depth
124129

125130
def bind(self, field_name, parent):
126131
super().bind(field_name, parent)
127-
if hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None:
132+
if self.max_depth is None and hasattr(parent, '_root_max_depth') and parent._root_max_depth is not None:
128133
self._root_max_depth = parent._root_max_depth
129134
self._current_depth = parent._current_depth + 1
130135

@@ -137,6 +142,32 @@ def _propagate_depth_to_child(self):
137142
if hasattr(field, '_propagate_depth_to_child'):
138143
field._propagate_depth_to_child()
139144

145+
def _check_data_depth(self, data, current_level):
146+
if isinstance(data, dict):
147+
for value in data.values():
148+
if isinstance(value, (list, tuple, dict)):
149+
next_level = current_level + 1
150+
if next_level > self._root_max_depth:
151+
message = self.error_messages['max_depth'].format(
152+
max_depth=self._root_max_depth
153+
)
154+
raise ValidationError({
155+
api_settings.NON_FIELD_ERRORS_KEY: [message]
156+
}, code='max_depth')
157+
self._check_data_depth(value, next_level)
158+
elif isinstance(data, (list, tuple)):
159+
for item in data:
160+
if isinstance(item, (list, tuple, dict)):
161+
next_level = current_level + 1
162+
if next_level > self._root_max_depth:
163+
message = self.error_messages['max_depth'].format(
164+
max_depth=self._root_max_depth
165+
)
166+
raise ValidationError({
167+
api_settings.NON_FIELD_ERRORS_KEY: [message]
168+
}, code='max_depth')
169+
self._check_data_depth(item, next_level)
170+
140171
def __new__(cls, *args, **kwargs):
141172
# We override this method in order to automatically create
142173
# `ListSerializer` classes instead when `many=True` is set.
@@ -390,7 +421,8 @@ def fields(self):
390421
fields = BindingDict(self)
391422
for key, value in self.get_fields().items():
392423
fields[key] = value
393-
self._propagate_depth_to_child()
424+
if self._root_max_depth is not None:
425+
self._propagate_depth_to_child()
394426
return fields
395427

396428
@property
@@ -507,6 +539,9 @@ def to_internal_value(self, data):
507539
raise ValidationError({
508540
api_settings.NON_FIELD_ERRORS_KEY: [message]
509541
}, code='invalid')
542+
if self._root_max_depth is not None and self.max_depth is not None:
543+
start_level = self._current_depth
544+
self._check_data_depth(data, start_level)
510545

511546
ret = {}
512547
errors = {}
@@ -672,6 +707,32 @@ def run_child_validation(self, data):
672707
"""
673708
return self.child.run_validation(data)
674709

710+
def _check_data_depth(self, data, current_level):
711+
if isinstance(data, (list, tuple)):
712+
for item in data:
713+
if isinstance(item, (list, tuple, dict)):
714+
next_level = current_level + 1
715+
if next_level > self._root_max_depth:
716+
message = self.error_messages['max_depth'].format(
717+
max_depth=self._root_max_depth
718+
)
719+
raise ValidationError({
720+
api_settings.NON_FIELD_ERRORS_KEY: [message]
721+
}, code='max_depth')
722+
self._check_data_depth(item, next_level)
723+
elif isinstance(data, dict):
724+
for value in data.values():
725+
if isinstance(value, (list, tuple, dict)):
726+
next_level = current_level + 1
727+
if next_level > self._root_max_depth:
728+
message = self.error_messages['max_depth'].format(
729+
max_depth=self._root_max_depth
730+
)
731+
raise ValidationError({
732+
api_settings.NON_FIELD_ERRORS_KEY: [message]
733+
}, code='max_depth')
734+
self._check_data_depth(value, next_level)
735+
675736
def to_internal_value(self, data):
676737
"""
677738
List of dicts of native values <- List of dicts of primitive datatypes.
@@ -687,6 +748,10 @@ def to_internal_value(self, data):
687748
api_settings.NON_FIELD_ERRORS_KEY: [message]
688749
}, code='not_a_list')
689750

751+
if self._root_max_depth is not None and self.max_depth is not None:
752+
start_level = self._current_depth
753+
self._check_data_depth(data, start_level)
754+
690755
if not self.allow_empty and len(data) == 0:
691756
message = self.error_messages['empty']
692757
raise ValidationError({

0 commit comments

Comments
 (0)