diff --git a/snuba/admin/clickhouse/copy_tables.py b/snuba/admin/clickhouse/copy_tables.py
index 52cdd90bb7..0ba6bccfea 100644
--- a/snuba/admin/clickhouse/copy_tables.py
+++ b/snuba/admin/clickhouse/copy_tables.py
@@ -1,3 +1,4 @@
+import re
 from dataclasses import dataclass
 from typing import MutableMapping, Optional, Sequence, Tuple, TypedDict
 
@@ -140,6 +141,13 @@ def copy_tables(
     if skip_on_cluster:
         cluster_name = None
     elif cluster_name_override:
+        # Validate cluster_name_override to prevent SQL injection: only
+        # alphanumeric characters, underscores, and hyphens are allowed.
+        # fullmatch() is required: "$" would also accept a trailing newline.
+        if not re.fullmatch(r"[a-zA-Z0-9_-]+", cluster_name_override):
+            raise ValueError(
+                "Invalid cluster name: only alphanumeric characters, underscores, and hyphens are allowed"
+            )
         cluster_name = cluster_name_override
     elif not cluster.is_single_node():
         cluster_name = storage.get_cluster().get_clickhouse_cluster_name()
diff --git a/tests/admin/clickhouse/test_copy_tables.py b/tests/admin/clickhouse/test_copy_tables.py
index ac696e375c..f16febdf8e 100644
--- a/tests/admin/clickhouse/test_copy_tables.py
+++ b/tests/admin/clickhouse/test_copy_tables.py
@@ -243,6 +243,41 @@ def test_copy_tables_cluster_name_override() -> None:
     assert result["cluster_name"] == override
 
 
+@pytest.mark.redis_db
+@pytest.mark.custom_clickhouse_db
+def test_copy_tables_cluster_name_override_sql_injection_prevention() -> None:
+    """
+    Test that cluster_name_override rejects invalid characters to prevent SQL injection.
+    Only alphanumeric characters, underscores, and hyphens should be allowed.
+    """
+    run_migrations()
+    host = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1")
+
+    # Test various SQL injection attempts
+    invalid_cluster_names = [
+        "'; DROP TABLE users; --",
+        "cluster'; DROP TABLE users; --",
+        "cluster' OR '1'='1",
+        'cluster"; DROP TABLE users; --',
+        "cluster name with spaces",
+        "cluster;name",
+        "cluster'name",
+        "cluster(name)",
+        "cluster.name",
+        # A trailing newline must be rejected too ("$" alone would accept it)
+        "cluster\n",
+    ]
+
+    for invalid_name in invalid_cluster_names:
+        with pytest.raises(ValueError, match="Invalid cluster name"):
+            copy_tables(
+                source_host=host,
+                storage_name="outcomes_raw",
+                dry_run=True,
+                cluster_name_override=invalid_name,
+            )
+
+
 @pytest.mark.redis_db
 @pytest.mark.custom_clickhouse_db
 def test_verify_tables_on_replicas() -> None: