From 6ee6cf118bbe83d7dff3800daf070dcb1a1e6aa0 Mon Sep 17 00:00:00 2001
From: Lidang-Jiang
Date: Fri, 3 Apr 2026 14:51:04 +0800
Subject: [PATCH] fix: replace hardcoded API key with secure random token

The vLLM server was initialized with api_key="default", allowing
unauthenticated access with a well-known credential (CWE-798).
Replace with _generate_api_key() that uses secrets.token_urlsafe(32),
with ART_API_KEY env var override for operators who need a pinned key.
Also remove "default" fallbacks in client code.

Fixes #628

Signed-off-by: Lidang-Jiang
---
 src/art/dev/openai_server.py          |  15 ++-
 src/art/local/backend.py              |   9 +-
 src/art/tinker/server.py              |   2 +-
 src/art/tinker_native/backend.py      |   2 +-
 tests/unit/test_api_key_generation.py | 134 ++++++++++++++++++++++++++
 5 files changed, 158 insertions(+), 4 deletions(-)
 create mode 100644 tests/unit/test_api_key_generation.py

diff --git a/src/art/dev/openai_server.py b/src/art/dev/openai_server.py
index b3b8ab535..772701446 100644
--- a/src/art/dev/openai_server.py
+++ b/src/art/dev/openai_server.py
@@ -5,6 +5,19 @@
 from .engine import EngineArgs
 
 
+def _generate_api_key() -> str:
+    """Return a secure, unique API key for the vLLM server.
+
+    Prefers the ``ART_API_KEY`` environment variable so operators can pin
+    a known credential. Falls back to a cryptographically random token
+    that is different for every server invocation.
+    """
+    import os
+    import secrets
+
+    return os.environ.get("ART_API_KEY") or secrets.token_urlsafe(32)
+
+
 def get_openai_server_config(
     model_name: str,
     base_model: str,
@@ -27,7 +40,7 @@ def get_openai_server_config(
         lora_modules = [f'{{"name": "{model_name}@{step}", "path": "{lora_path}"}}']
 
     server_args = ServerArgs(
-        api_key="default",
+        api_key=_generate_api_key(),
         lora_modules=lora_modules,
         return_tokens_as_token_ids=True,
         enable_auto_tool_choice=True,
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index 43e35449b..f02e8eb93 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -422,6 +422,13 @@ async def _prepare_backend_for_training(
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
             s.bind(("", 0))
             server_args["port"] = s.getsockname()[1]
+
+        # Ensure the server and client share the same API key.
+        # If the caller did not supply one, generate a secure random key
+        # so the vLLM server is never exposed with a well-known credential.
+        if not server_args.get("api_key"):
+            server_args["api_key"] = dev._generate_api_key()
+
         config_dict["server_args"] = server_args
 
         resolved_config = cast(dev.OpenAIServerConfig, config_dict)
@@ -429,7 +436,7 @@ async def _prepare_backend_for_training(
         host, port = await service.start_openai_server(config=resolved_config)
         base_url = f"http://{host}:{port}/v1"
 
-        api_key = server_args.get("api_key") or "default"
+        api_key = server_args["api_key"]
 
         def done_callback(_: asyncio.Task[None]) -> None:
             close_proxy(self._services.pop(model.name))
diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py
index e7fffaf92..8b7396cfa 100644
--- a/src/art/tinker/server.py
+++ b/src/art/tinker/server.py
@@ -140,7 +140,7 @@ async def start(self) -> tuple[str, int]:
             for i in range(self.num_workers or self._default_num_workers())
         ]
         self._task = asyncio.create_task(self._run(host, port))
-        client = AsyncOpenAI(api_key="default", base_url=f"http://{host}:{port}/v1")
+        client = AsyncOpenAI(api_key="health-check", base_url=f"http://{host}:{port}/v1")
         start = time.time()
         while True:
             timeout = float(os.environ.get("ART_SERVER_TIMEOUT", 300.0))
diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py
index c1687bf7f..8f99198cf 100644
--- a/src/art/tinker_native/backend.py
+++ b/src/art/tinker_native/backend.py
@@ -213,7 +213,7 @@ async def _prepare_backend_for_training(
         port = server_args.get("port", raw_config.get("port"))
         if port is None:
             port = get_free_port()
-        api_key = server_args.get("api_key", raw_config.get("api_key")) or "default"
+        api_key = server_args.get("api_key", raw_config.get("api_key")) or dev._generate_api_key()
 
         if state.server_task is None:
             state.server_host = host
diff --git a/tests/unit/test_api_key_generation.py b/tests/unit/test_api_key_generation.py
new file mode 100644
index 000000000..0b4f2fa4c
--- /dev/null
+++ b/tests/unit/test_api_key_generation.py
@@ -0,0 +1,134 @@
+"""Unit tests for secure API key generation in vLLM server configuration.
+
+Verifies that the hardcoded 'default' API key (CWE-798) has been replaced
+with a cryptographically random token, and that the ART_API_KEY environment
+variable override works correctly.
+
+See: https://github.com/OpenPipe/ART/issues/628
+"""
+
+import os
+
+import pytest
+
+from art.dev.openai_server import (
+    OpenAIServerConfig,
+    _generate_api_key,
+    get_openai_server_config,
+)
+
+
+class TestGenerateApiKey:
+    """Tests for the ``_generate_api_key`` helper."""
+
+    def test_key_is_not_hardcoded_default(self) -> None:
+        """The generated key must never be the literal string 'default'."""
+        key = _generate_api_key()
+        assert key != "default"
+
+    def test_key_is_nonempty_string(self) -> None:
+        key = _generate_api_key()
+        assert isinstance(key, str)
+        assert len(key) > 0
+
+    def test_key_has_sufficient_entropy(self) -> None:
+        """A 32-byte urlsafe token encodes to >= 32 characters."""
+        key = _generate_api_key()
+        assert len(key) >= 32
+
+    def test_keys_are_unique_across_calls(self) -> None:
+        """Each invocation should produce a different random key."""
+        keys = {_generate_api_key() for _ in range(20)}
+        assert len(keys) == 20
+
+    def test_env_var_override(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        """ART_API_KEY environment variable should take precedence."""
+        monkeypatch.setenv("ART_API_KEY", "my-custom-key-42")
+        assert _generate_api_key() == "my-custom-key-42"
+
+    def test_env_var_empty_falls_back_to_random(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """An empty ART_API_KEY should be treated as unset."""
+        monkeypatch.setenv("ART_API_KEY", "")
+        key = _generate_api_key()
+        assert key != ""
+        assert key != "default"
+
+    def test_env_var_unset_falls_back_to_random(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """When ART_API_KEY is not in the environment, use random key."""
+        monkeypatch.delenv("ART_API_KEY", raising=False)
+        key = _generate_api_key()
+        assert key != "default"
+        assert len(key) >= 32
+
+
+class TestGetOpenaiServerConfig:
+    """Tests for ``get_openai_server_config`` API key behaviour."""
+
+    def test_default_config_uses_random_key(self) -> None:
+        """Without user-supplied config, the key must not be 'default'."""
+        config = get_openai_server_config(
+            model_name="test-model",
+            base_model="base-model",
+            log_file="/tmp/test.log",
+        )
+        api_key = config["server_args"]["api_key"]
+        assert api_key != "default"
+        assert isinstance(api_key, str)
+        assert len(api_key) >= 32
+
+    def test_user_supplied_key_takes_precedence(self) -> None:
+        """A user-provided api_key in server_args should override the default."""
+        user_config = OpenAIServerConfig(
+            server_args={"api_key": "user-provided-key-xyz"}  # type: ignore[typeddict-item]
+        )
+        config = get_openai_server_config(
+            model_name="test-model",
+            base_model="base-model",
+            log_file="/tmp/test.log",
+            config=user_config,
+        )
+        assert config["server_args"]["api_key"] == "user-provided-key-xyz"
+
+    def test_env_var_key_used_when_no_user_config(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """ART_API_KEY env var should be picked up when no user key is given."""
+        monkeypatch.setenv("ART_API_KEY", "env-key-123")
+        config = get_openai_server_config(
+            model_name="test-model",
+            base_model="base-model",
+            log_file="/tmp/test.log",
+        )
+        assert config["server_args"]["api_key"] == "env-key-123"
+
+    def test_user_key_overrides_env_var(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """User-supplied key should win over ART_API_KEY env var."""
+        monkeypatch.setenv("ART_API_KEY", "env-key-123")
+        user_config = OpenAIServerConfig(
+            server_args={"api_key": "user-wins"}  # type: ignore[typeddict-item]
+        )
+        config = get_openai_server_config(
+            model_name="test-model",
+            base_model="base-model",
+            log_file="/tmp/test.log",
+            config=user_config,
+        )
+        assert config["server_args"]["api_key"] == "user-wins"
+
+    def test_each_config_call_gets_unique_key(self) -> None:
+        """Successive calls should not share the same random key."""
+        keys = set()
+        for _ in range(10):
+            config = get_openai_server_config(
+                model_name="test-model",
+                base_model="base-model",
+                log_file="/tmp/test.log",
+            )
+            keys.add(config["server_args"]["api_key"])
+        assert len(keys) == 10