Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion memoryos-playground/memdemo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import os
import json
import re
import shutil
from datetime import datetime
import secrets
Expand All @@ -20,6 +21,33 @@
# Global memoryos instance (in production, you'd use proper session management)
memory_systems = {}

# Regex for safe identifiers: alphanumeric, hyphens, underscores, dots (but not
# ".." or lone "."), and no path separators.
_SAFE_ID_RE = re.compile(r'^[A-Za-z0-9][A-Za-z0-9._-]*$')


def _is_safe_id(value: str) -> bool:
"""Return True if *value* is a safe identifier for use in filesystem paths.

Rejects path separators, '..' components, null bytes, and empty strings.
"""
if not value:
return False
if '\x00' in value:
return False
if not _SAFE_ID_RE.match(value):
return False
# Reject any component that is exactly '.' or '..'
for part in value.replace('\\', '/').split('/'):
if part in ('.', '..'):
return False
# Belt-and-suspenders: after joining, the resolved path must stay inside
# the expected parent directory.
if '..' in value or '/' in value or '\\' in value:
return False
return True


# 删除了固定的API_KEY, BASE_URL, MODEL

# 有效邀请码列表 - 在实际部署中应该存储在数据库或加密文件中
Expand Down Expand Up @@ -73,6 +101,10 @@ def init_memory():

if not user_id or not api_key or not base_url or not model:
return jsonify({'error': 'User ID, API Key, Base URL, and Model Name are required.'}), 400

# Validate user_id to prevent path traversal (CWE-22)
if not _is_safe_id(user_id):
return jsonify({'error': 'Invalid user_id. Only alphanumeric characters, hyphens, underscores, and dots are allowed.'}), 400

assistant_id = f"assistant_{user_id}"

Expand Down Expand Up @@ -424,4 +456,4 @@ def import_conversations():
return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=5019)
app.run(debug=True, host='0.0.0.0', port=5019)
128 changes: 128 additions & 0 deletions tests/test_cwe22_app_flask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
PoC test for CWE-22 path traversal via user_id in Flask web app.

Tests that the /init_memory endpoint rejects user_id values containing
path traversal sequences (e.g., '..', '/', '\0').
"""
import sys
import os
import json
import types

# We test the Flask app directly via its test client - no need for the full
# Memoryos stack (which requires ML models). We only need to verify that
# the input validation rejects malicious user_id values before they reach
# the Memoryos constructor.

WORKTREE = os.environ.get('WORKTREE', os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Mock out the heavy memoryos imports before loading the Flask app
class MockMemoryos:
"""Captures constructor args so we can inspect what user_id was passed."""
instances = []

def __init__(self, **kwargs):
self.__dict__.update(kwargs)
MockMemoryos.instances.append(self)
self.user_data_dir = os.path.join(
kwargs.get('data_storage_path', '/tmp'),
'users',
kwargs.get('user_id', '')
)
self.assistant_data_dir = os.path.join(
kwargs.get('data_storage_path', '/tmp'),
'assistants',
kwargs.get('assistant_id', '')
)


mock_memoryos_module = types.ModuleType('memoryos')
mock_memoryos_module.Memoryos = MockMemoryos

mock_utils_module = types.ModuleType('memoryos.utils')
mock_utils_module.get_timestamp = lambda: "2025-01-01 00:00:00"

sys.modules['memoryos'] = mock_memoryos_module
sys.modules['memoryos.utils'] = mock_utils_module

# Now add the app directory and import it
sys.path.insert(0, os.path.join(WORKTREE, 'memoryos-playground', 'memdemo'))
import app as flask_app

client = flask_app.app.test_client()


def post_init(user_id):
"""Helper: POST /init_memory with the given user_id."""
return client.post('/init_memory', json={
'user_id': user_id,
'api_key': 'sk-test-fake',
'base_url': 'http://localhost:9999',
'model_name': 'gpt-4o-mini',
})


# ---- Attack vectors ----

TRAVERSAL_PAYLOADS = [
'../etc/cron.d', # classic unix traversal
'..\\windows\\system32', # windows traversal
'foo/../../etc/passwd', # embedded traversal
'foo/../../../tmp/evil', # deeper traversal
'....//....//etc', # double-dot-slash bypass
'/absolute/path', # absolute path
'.', # current directory reference
'..', # parent directory reference
]

SAFE_PAYLOADS = [
'alice',
'bob_123',
'user-2024',
'CamelCaseUser',
]


def test_traversal_payloads_rejected():
"""All path-traversal payloads must be rejected with 400."""
for payload in TRAVERSAL_PAYLOADS:
resp = post_init(payload)
assert resp.status_code == 400, (
f"VULN: user_id={payload!r} was accepted (status {resp.status_code}). "
f"Response: {resp.get_data(as_text=True)}"
)
body = resp.get_json()
assert 'error' in body, f"Expected error key in response for {payload!r}"
print(f" PASS rejected: {payload!r} -> 400")


def test_safe_payloads_accepted():
"""Legitimate user_id values must still be accepted."""
for payload in SAFE_PAYLOADS:
MockMemoryos.instances.clear()
resp = post_init(payload)
body = resp.get_json()
if resp.status_code == 400:
err = body.get('error', '').lower()
assert not ('user_id' in err and 'invalid' in err), (
f"False positive: safe user_id={payload!r} was rejected: {body}"
)
print(f" PASS accepted: {payload!r} -> {resp.status_code}")


if __name__ == '__main__':
print("--- Testing path traversal payloads are rejected ---")
try:
test_traversal_payloads_rejected()
except AssertionError as e:
print(f"\nFAIL: {e}")
sys.exit(1)

print("\n--- Testing safe payloads are accepted ---")
try:
test_safe_payloads_accepted()
except AssertionError as e:
print(f"\nFAIL: {e}")
sys.exit(1)

print("\nAll tests passed!")