Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/appengine/handlers/testcase_detail/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,21 @@ class TaskLogHandler(base_handler.Handler):
@handler.get(handler.TEXT)
def get(self):
"""Serve the task log."""
testcase_id = flask.request.args.get('testcase_id')

# Verify the user is authenticated and has access to this testcase.
testcase = access.check_access_and_get_testcase(testcase_id)

task_id = flask.request.args.get('task_id')
task_name = flask.request.args.get('task_name')
testcase_id = flask.request.args.get('testcase_id')

# Validate task_name against the known set to prevent filter injection.
valid_task_names = (
testcase_status_events.TestcaseStatusInfo.TASK_EVENTS_NAMES +
testcase_status_events.TestcaseStatusInfo.CHROME_TASK_EVENTS_NAMES)
if task_name and task_name not in valid_task_names:
raise helpers.EarlyExitError('Invalid task name.', 400)

log_content = testcase_status_events.get_task_log(testcase_id, task_id,
task_name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,21 @@ def _get_time_range_filter(self, days: int) -> str:
start_time = utils.utcnow() - datetime.timedelta(days=days)
return f'timestamp >= "{start_time.isoformat()}Z"'

@staticmethod
def _sanitize_filter_value(value: str) -> str:
"""Sanitize a value for use in a Cloud Logging filter string.

Escapes double quotes and backslashes to prevent filter injection."""
return value.replace('\\', '\\\\').replace('"', '\\"')

def _get_task_log_query_filter(self, task_id: str, task_name: str) -> str:
"""Returns the filter string for querying task logs."""
query = (f'jsonPayload.extras.task_id="{task_id}" AND '
f'jsonPayload.extras.testcase_id="{self._testcase_id}" AND '
f'jsonPayload.extras.task_name="{task_name}"')
safe_task_id = self._sanitize_filter_value(task_id)
safe_task_name = self._sanitize_filter_value(task_name)
safe_testcase_id = self._sanitize_filter_value(str(self._testcase_id))
query = (f'jsonPayload.extras.task_id="{safe_task_id}" AND '
f'jsonPayload.extras.testcase_id="{safe_testcase_id}" AND '
f'jsonPayload.extras.task_name="{safe_task_name}"')
query += f' AND {self._get_time_range_filter(days=31)}'
return query

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,33 @@ def test_unreproducible_get(self):
self.assertDictContainsSubset({
'lines': [show.Line(1, 'crash_stacktrace', False)]
}, result['last_tested_crash_stacktrace'])


class TaskLogHandlerValidationTest(unittest.TestCase):
"""Test that TaskLogHandler validates task_name against the known set."""

def test_valid_task_names(self):
"""Verify the known valid task names."""
from handlers.testcase_detail import testcase_status_events
valid_names = (
testcase_status_events.TestcaseStatusInfo.TASK_EVENTS_NAMES +
testcase_status_events.TestcaseStatusInfo.CHROME_TASK_EVENTS_NAMES)

self.assertIn('analyze', valid_names)
self.assertIn('minimize', valid_names)
self.assertIn('progression', valid_names)
self.assertIn('regression', valid_names)
self.assertIn('variant', valid_names)
self.assertIn('blame', valid_names)
self.assertIn('impact', valid_names)

def test_injection_payloads_rejected(self):
"""Verify that injection payloads are not valid task names."""
from handlers.testcase_detail import testcase_status_events
valid_names = (
testcase_status_events.TestcaseStatusInfo.TASK_EVENTS_NAMES +
testcase_status_events.TestcaseStatusInfo.CHROME_TASK_EVENTS_NAMES)

self.assertNotIn('analyze" OR "1"="1', valid_names)
self.assertNotIn('evil_task', valid_names)
self.assertNotIn('', valid_names)
Original file line number Diff line number Diff line change
Expand Up @@ -834,3 +834,34 @@ def test_get_task_log_api_call(self):
self.assertLess(
result.find('"payload": "log1"'), result.find('"payload": "log2"'))
self.assertEqual(result.count('\n'), 5)


class SanitizeFilterValueTest(unittest.TestCase):
"""Test _sanitize_filter_value for Cloud Logging filter injection."""

def test_normal_value(self):
"""Test that normal values pass through unchanged."""
self.assertEqual(
testcase_status_events.TestcaseEventHistory._sanitize_filter_value(
'analyze'),
'analyze')

def test_double_quote_escaped(self):
"""Test that double quotes are escaped to prevent filter injection."""
self.assertEqual(
testcase_status_events.TestcaseEventHistory._sanitize_filter_value(
'analyze" OR resource.type="global'),
'analyze\\" OR resource.type=\\"global')

def test_backslash_escaped(self):
"""Test that backslashes are escaped."""
self.assertEqual(
testcase_status_events.TestcaseEventHistory._sanitize_filter_value(
'test\\value'),
'test\\\\value')

def test_empty_string(self):
"""Test empty string."""
self.assertEqual(
testcase_status_events.TestcaseEventHistory._sanitize_filter_value(''),
'')