diff --git a/packages/prime/src/prime_cli/commands/config.py b/packages/prime/src/prime_cli/commands/config.py index bcdd61280..47a45ac1d 100644 --- a/packages/prime/src/prime_cli/commands/config.py +++ b/packages/prime/src/prime_cli/commands/config.py @@ -16,6 +16,16 @@ # Team ID validation pattern: CUID (v1) TEAM_ID_PATTERN = re.compile(r"^c[a-z0-9]{24}$") +PROFILE_OVERRIDE_ENV_VARS = ( + "PRIME_API_KEY", + "PRIME_TEAM_ID", + "PRIME_USER_ID", + "PRIME_API_BASE_URL", + "PRIME_BASE_URL", + "PRIME_FRONTEND_URL", + "PRIME_INFERENCE_URL", + "PRIME_CONTEXT", +) def validate_team_id(team_id: str) -> bool: @@ -32,6 +42,36 @@ def validate_team_id(team_id: str) -> bool: return bool(TEAM_ID_PATTERN.match(team_id)) +def _active_profile_override_env_vars() -> list[str]: + return [name for name in PROFILE_OVERRIDE_ENV_VARS if _env_var_overrides_profile(name)] + + +def _env_var_overrides_profile(name: str) -> bool: + value = os.getenv(name) + if value is None: + return False + if name == "PRIME_TEAM_ID": + return bool(value.strip()) + if name == "PRIME_USER_ID": + return True + if name == "PRIME_CONTEXT" and Config.context_from_cli_option(): + return False + return bool(value) + + +def _require_profile_env_unset(command: str) -> None: + names = _active_profile_override_env_vars() + if not names: + return + joined = ", ".join(names) + console.print( + f"[red]Error:[/red] {joined} {'is' if len(names) == 1 else 'are'} set in your " + f"environment, so [bold]prime config {command}[/bold] cannot make a saved profile " + "active. Unset the environment override and rerun the command." + ) + raise typer.Exit(1) + + @app.command() def view() -> None: """View current configuration""" @@ -278,6 +318,7 @@ def _set_environment( env: str, ) -> None: """Set URLs for a specific environment""" + _require_profile_env_unset(f"use {env}") config = Config() # Try to load the environment (handles both built-in and custom) diff --git a/packages/prime/src/prime_cli/core/config.py b/packages/prime/src/prime_cli/core/config.py index b3ef6e80e..b3030b7ef 100644 --- a/packages/prime/src/prime_cli/core/config.py +++ b/packages/prime/src/prime_cli/core/config.py @@ -2,7 +2,7 @@ import os import re from pathlib import Path -from typing import Optional +from typing import ClassVar, Optional from pydantic import BaseModel, ConfigDict @@ -28,6 +28,15 @@ class Config: DEFAULT_FRONTEND_URL: str = "https://app.primeintellect.ai" DEFAULT_INFERENCE_URL: str = "https://api.pinference.ai/api/v1" DEFAULT_SSH_KEY_PATH: str = str(Path.home() / ".ssh" / "id_rsa") + _context_from_cli_option: ClassVar[bool] = False + + @classmethod + def set_context_from_cli_option(cls, value: bool) -> None: + cls._context_from_cli_option = value + + @classmethod + def context_from_cli_option(cls) -> bool: + return cls._context_from_cli_option def __init__(self) -> None: self.config_dir = Path.home() / ".prime" @@ -355,27 +364,30 @@ def load_environment(self, name: str, persist: bool = True) -> bool: return False def update_current_environment_file(self) -> None: - """Update the current environment's saved file with current config""" - if self.current_environment != "production": - # Only update custom environments, not the built-in production - try: - sanitized_name = self._sanitize_environment_name(self.current_environment) - env_file = self.environments_dir / f"{sanitized_name}.json" - if env_file.exists(): - env_config = { - "api_key": self.api_key, - "team_id": self.team_id, - "team_name": None if self.team_id_from_env else self.team_name, - "team_role": None if self.team_id_from_env else self.team_role, - "user_id": self.user_id, - "base_url": self.base_url, - "frontend_url": self.frontend_url, - "inference_url": self.inference_url, - } - env_file.write_text(json.dumps(env_config, indent=2)) - except ValueError: - # Skip updating if environment name is invalid - pass + """Update the active saved environment with the persisted config values.""" + if self.current_environment == "production": + return + + try: + sanitized_name = self._sanitize_environment_name(self.current_environment) + except ValueError: + return + + env_file = self.environments_dir / f"{sanitized_name}.json" + if not env_file.exists(): + return + + env_config = { + "api_key": self.config.get("api_key", ""), + "team_id": self.config.get("team_id"), + "team_name": self.config.get("team_name"), + "team_role": self.config.get("team_role"), + "user_id": self.config.get("user_id"), + "base_url": self.config.get("base_url", self.DEFAULT_BASE_URL), + "frontend_url": self.config.get("frontend_url", self.DEFAULT_FRONTEND_URL), + "inference_url": self.config.get("inference_url", self.DEFAULT_INFERENCE_URL), + } + env_file.write_text(json.dumps(env_config, indent=2)) def list_environments(self) -> list[str]: """List all saved environment names""" diff --git a/packages/prime/src/prime_cli/main.py b/packages/prime/src/prime_cli/main.py index 5bba5fc8c..684aab629 100644 --- a/packages/prime/src/prime_cli/main.py +++ b/packages/prime/src/prime_cli/main.py @@ -107,7 +107,20 @@ def callback( typer.echo(f" - {env_name}", err=True) raise typer.Exit(1) - # Set environment variable so Config instances in subcommands pick it up + previous_context = os.environ.get("PRIME_CONTEXT") + previous_context_from_cli_option = Config.context_from_cli_option() + + def restore_context() -> None: + if previous_context is None: + os.environ.pop("PRIME_CONTEXT", None) + else: + os.environ["PRIME_CONTEXT"] = previous_context + Config.set_context_from_cli_option(previous_context_from_cli_option) + + ctx.call_on_close(restore_context) + + # Set environment variable so Config instances in subcommands pick it up. + Config.set_context_from_cli_option(True) os.environ["PRIME_CONTEXT"] = context # Check for updates (only when a subcommand is being executed) diff --git a/packages/prime/tests/test_config_delete.py b/packages/prime/tests/test_config_delete.py index 4696b01eb..b0a29116a 100644 --- a/packages/prime/tests/test_config_delete.py +++ b/packages/prime/tests/test_config_delete.py @@ -18,6 +18,17 @@ def temp_home(tmp_path: Any, monkeypatch: pytest.MonkeyPatch) -> Path: monkeypatch.setenv("HOME", str(tmp_path)) monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + for name in ( + "PRIME_API_KEY", + "PRIME_TEAM_ID", + "PRIME_USER_ID", + "PRIME_API_BASE_URL", + "PRIME_BASE_URL", + "PRIME_FRONTEND_URL", + "PRIME_INFERENCE_URL", + "PRIME_CONTEXT", + ): + monkeypatch.delenv(name, raising=False) return tmp_path diff --git a/packages/prime/tests/test_config_profiles.py b/packages/prime/tests/test_config_profiles.py new file mode 100644 index 000000000..32e987d9b --- /dev/null +++ b/packages/prime/tests/test_config_profiles.py @@ -0,0 +1,151 @@ +import json +import os +from pathlib import Path +from typing import Any + +import pytest +from prime_cli.core import Config +from prime_cli.main import app +from typer.testing import CliRunner + +runner = CliRunner() + +TEST_ENV = { + "COLUMNS": "200", + "LINES": "50", + "PRIME_DISABLE_VERSION_CHECK": "1", +} + + +@pytest.fixture +def temp_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + Config.set_context_from_cli_option(False) + for name in ( + "PRIME_API_KEY", + "PRIME_TEAM_ID", + "PRIME_USER_ID", + "PRIME_API_BASE_URL", + "PRIME_BASE_URL", + "PRIME_FRONTEND_URL", + "PRIME_INFERENCE_URL", + "PRIME_CONTEXT", + ): + monkeypatch.delenv(name, raising=False) + return tmp_path + + +def _saved_profile(home: Path, name: str) -> dict[str, Any]: + return json.loads((home / ".prime" / "environments" / f"{name}.json").read_text()) + + +def test_save_use_restores_saved_team(temp_home: Path) -> None: + config = Config() + config.set_api_key("key-team-one") + config.set_team("team-one", team_name="Team One", team_role="ADMIN") + config.save_environment("team1") + + config.set_api_key("key-team-two") + config.set_team("team-two", team_name="Team Two", team_role="MEMBER") + config.save_environment("team2") + + assert config.load_environment("team1") + + reloaded = Config() + assert reloaded.api_key == "key-team-one" + assert reloaded.team_id == "team-one" + assert reloaded.team_name == "Team One" + assert reloaded.current_environment == "team1" + + +@pytest.mark.parametrize( + ("name", "value"), + [ + ("PRIME_API_KEY", "env-key"), + ("PRIME_API_KEY", " "), + ("PRIME_USER_ID", ""), + ("PRIME_CONTEXT", "team1"), + ], +) +def test_config_use_fails_when_env_override_masks_profile( + temp_home: Path, name: str, value: str +) -> None: + config = Config() + config.save_environment("team1") + config.save_environment("staging") + + result = runner.invoke( + app, + ["config", "use", "staging"], + env={**TEST_ENV, name: value}, + ) + + assert result.exit_code == 1, result.output + assert f"{name} is set in your environment" in result.output + assert "prime config use staging" in result.output + + +def test_context_option_does_not_block_config_use(temp_home: Path) -> None: + config = Config() + config.set_api_key("key-team-one") + config.save_environment("team1") + config.set_api_key("key-staging") + config.save_environment("staging") + + result = runner.invoke(app, ["--context", "team1", "config", "use", "staging"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + + reloaded = Config() + assert reloaded.current_environment == "staging" + assert reloaded.api_key == "key-staging" + assert os.getenv("PRIME_CONTEXT") is None + assert not Config.context_from_cli_option() + + +def test_active_profile_update_uses_persisted_values_not_env_overrides( + temp_home: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + config = Config() + config.set_api_key("profile-key") + config.set_team("team-one", team_name="Team One", team_role="ADMIN") + config.save_environment("profile") + assert config.load_environment("profile") + + monkeypatch.setenv("PRIME_API_KEY", "env-key") + + config.set_team("team-two", team_name="Team Two", team_role="MEMBER") + config.update_current_environment_file() + + saved = _saved_profile(temp_home, "profile") + assert saved["api_key"] == "profile-key" + assert saved["team_id"] == "team-two" + + +def test_set_api_key_syncs_active_saved_profile_after_whoami( + temp_home: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + class FakeAPIClient: + def __init__(self, api_key: str) -> None: + self.api_key = api_key + + def get(self, endpoint: str) -> dict[str, dict[str, str]]: + assert endpoint == "/user/whoami" + assert self.api_key == "new-key" + return {"data": {"id": "user-new"}} + + monkeypatch.setattr("prime_cli.commands.config.APIClient", FakeAPIClient) + + config = Config() + config.set_api_key("old-key") + config.set_user_id("user-old") + config.save_environment("profile") + assert config.load_environment("profile") + + result = runner.invoke(app, ["config", "set-api-key", "new-key"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + saved = _saved_profile(temp_home, "profile") + assert saved["api_key"] == "new-key" + assert saved["user_id"] == "user-new"